delete local harness and remove imports
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- lm-evaluation-harness/.coveragerc +0 -28
- lm-evaluation-harness/.flake8 +0 -5
- lm-evaluation-harness/.github/workflows/new_tasks.yml +0 -71
- lm-evaluation-harness/.github/workflows/publish.yml +0 -97
- lm-evaluation-harness/.github/workflows/unit_tests.yml +0 -114
- lm-evaluation-harness/.gitignore +0 -47
- lm-evaluation-harness/.pre-commit-config.yaml +0 -60
- lm-evaluation-harness/CITATION.bib +0 -10
- lm-evaluation-harness/CODEOWNERS +0 -1
- lm-evaluation-harness/LICENSE.md +0 -21
- lm-evaluation-harness/MANIFEST.in +0 -1
- lm-evaluation-harness/README.md +0 -625
- lm-evaluation-harness/ignore.txt +0 -8
- lm-evaluation-harness/lm_eval/__init__.py +0 -7
- lm-evaluation-harness/lm_eval/__main__.py +0 -530
- lm-evaluation-harness/lm_eval/api/filter.py +0 -56
- lm-evaluation-harness/lm_eval/api/group.py +0 -115
- lm-evaluation-harness/lm_eval/api/instance.py +0 -38
- lm-evaluation-harness/lm_eval/api/metrics.py +0 -578
- lm-evaluation-harness/lm_eval/api/model.py +0 -493
- lm-evaluation-harness/lm_eval/api/registry.py +0 -196
- lm-evaluation-harness/lm_eval/api/samplers.py +0 -232
- lm-evaluation-harness/lm_eval/api/task.py +0 -1879
- lm-evaluation-harness/lm_eval/caching/cache.py +0 -59
- lm-evaluation-harness/lm_eval/decontamination/__init__.py +0 -0
- lm-evaluation-harness/lm_eval/decontamination/archiver.py +0 -174
- lm-evaluation-harness/lm_eval/decontamination/decontaminate.py +0 -166
- lm-evaluation-harness/lm_eval/decontamination/janitor.py +0 -328
- lm-evaluation-harness/lm_eval/evaluator.py +0 -761
- lm-evaluation-harness/lm_eval/evaluator_utils.py +0 -554
- lm-evaluation-harness/lm_eval/filters/__init__.py +0 -25
- lm-evaluation-harness/lm_eval/filters/custom.py +0 -17
- lm-evaluation-harness/lm_eval/filters/decontamination.py +0 -25
- lm-evaluation-harness/lm_eval/filters/extraction.py +0 -233
- lm-evaluation-harness/lm_eval/filters/selection.py +0 -61
- lm-evaluation-harness/lm_eval/filters/transformation.py +0 -122
- lm-evaluation-harness/lm_eval/loggers/__init__.py +0 -2
- lm-evaluation-harness/lm_eval/loggers/evaluation_tracker.py +0 -537
- lm-evaluation-harness/lm_eval/loggers/utils.py +0 -149
- lm-evaluation-harness/lm_eval/loggers/wandb_logger.py +0 -358
- lm-evaluation-harness/lm_eval/models/__init__.py +0 -36
- lm-evaluation-harness/lm_eval/models/anthropic_llms.py +0 -367
- lm-evaluation-harness/lm_eval/models/api_models.py +0 -799
- lm-evaluation-harness/lm_eval/models/dummy.py +0 -41
- lm-evaluation-harness/lm_eval/models/gguf.py +0 -132
- lm-evaluation-harness/lm_eval/models/hf_audiolm.py +0 -307
- lm-evaluation-harness/lm_eval/models/hf_steered.py +0 -243
- lm-evaluation-harness/lm_eval/models/hf_vlms.py +0 -757
- lm-evaluation-harness/lm_eval/models/huggingface.py +0 -1480
- lm-evaluation-harness/lm_eval/models/ibm_watsonx_ai.py +0 -445
lm-evaluation-harness/.coveragerc
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
[run]
|
| 2 |
-
|
| 3 |
-
# tasks that aren't wired up.
|
| 4 |
-
omit =
|
| 5 |
-
lm_eval/tasks/quac.py
|
| 6 |
-
lm_eval/tasks/storycloze.py
|
| 7 |
-
lm_eval/tasks/cbt.py
|
| 8 |
-
lm_eval/tasks/sat.py
|
| 9 |
-
lm_eval/tasks/triviaqa.py
|
| 10 |
-
lm_eval/tasks/naturalqs.py
|
| 11 |
-
lm_eval/models/dummy.py
|
| 12 |
-
|
| 13 |
-
[report]
|
| 14 |
-
exclude_lines =
|
| 15 |
-
# Skip any pass lines such as may be used for @abstractmethod
|
| 16 |
-
pass
|
| 17 |
-
|
| 18 |
-
# Have to re-enable the standard pragma
|
| 19 |
-
pragma: no cover
|
| 20 |
-
|
| 21 |
-
# Don't complain about missing debug-only code:
|
| 22 |
-
def __repr__
|
| 23 |
-
if self\.debug
|
| 24 |
-
|
| 25 |
-
# Don't complain if tests don't hit defensive assertion code:
|
| 26 |
-
raise AssertionError
|
| 27 |
-
raise NotImplementedError
|
| 28 |
-
return NotImplemented
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/.flake8
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
[flake8]
|
| 2 |
-
ignore = E203, E266, E501, W503, F403, F401, C901
|
| 3 |
-
max-line-length = 127
|
| 4 |
-
max-complexity = 10
|
| 5 |
-
select = B,C,E,F,W,T4,B9
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/.github/workflows/new_tasks.yml
DELETED
|
@@ -1,71 +0,0 @@
|
|
| 1 |
-
name: Tasks Modified
|
| 2 |
-
|
| 3 |
-
on:
|
| 4 |
-
push:
|
| 5 |
-
branches:
|
| 6 |
-
- 'main'
|
| 7 |
-
pull_request:
|
| 8 |
-
branches:
|
| 9 |
-
- 'main'
|
| 10 |
-
workflow_dispatch:
|
| 11 |
-
# comment/edit out the above to stop/change the triggers
|
| 12 |
-
jobs:
|
| 13 |
-
changed_files:
|
| 14 |
-
runs-on: ubuntu-latest # windows-latest || macos-latest
|
| 15 |
-
timeout-minutes: 120
|
| 16 |
-
name: Scan for changed tasks
|
| 17 |
-
steps:
|
| 18 |
-
- name: checkout
|
| 19 |
-
uses: actions/checkout@v4
|
| 20 |
-
with:
|
| 21 |
-
fetch-depth: 2 # OR "2" -> To retrieve the preceding commit.
|
| 22 |
-
|
| 23 |
-
# Uses the tj-actions/changed-files action to check for changes.
|
| 24 |
-
# The `files_yaml` input optionally takes a yaml string to specify filters,
|
| 25 |
-
# and prepends the filter name to the standard output names.
|
| 26 |
-
- name: Check task folders
|
| 27 |
-
id: changed-tasks
|
| 28 |
-
uses: tj-actions/changed-files@v46.0.5
|
| 29 |
-
with:
|
| 30 |
-
# tasks checks the tasks folder and api checks the api folder for changes
|
| 31 |
-
files_yaml: |
|
| 32 |
-
tasks:
|
| 33 |
-
- lm_eval/tasks/**
|
| 34 |
-
api:
|
| 35 |
-
- lm_eval/api/**
|
| 36 |
-
write_output_files: true
|
| 37 |
-
|
| 38 |
-
# The next step is optional; the files are written to the workspace by default (above).
|
| 39 |
-
# so it's just for debugging
|
| 40 |
-
- name: Run Tests
|
| 41 |
-
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
|
| 42 |
-
run: |
|
| 43 |
-
echo .github/outputs/tasks_all_changed_and_modified_files.txt >> 'GITHUB_ENV'
|
| 44 |
-
echo "One or more test file(s) has changed."
|
| 45 |
-
echo "List of all the files that have changed: ${{ steps.changed-tasks.outputs.tasks_all_modified_files }}"
|
| 46 |
-
|
| 47 |
-
- name: Set up Python 3.9
|
| 48 |
-
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
|
| 49 |
-
uses: actions/setup-python@v5
|
| 50 |
-
with:
|
| 51 |
-
python-version: 3.9
|
| 52 |
-
cache: 'pip'
|
| 53 |
-
cache-dependency-path: setup.py
|
| 54 |
-
- name: Install dependencies
|
| 55 |
-
if: steps.changed-tasks.outputs.tasks_any_modified == 'true' || steps.changed-tasks.outputs.api_any_modified == 'true'
|
| 56 |
-
run: |
|
| 57 |
-
python -m pip install --upgrade pip
|
| 58 |
-
pip install -e '.[dev,ifeval]' --extra-index-url https://download.pytorch.org/whl/cpu
|
| 59 |
-
# Install optional git dependencies
|
| 60 |
-
# pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
|
| 61 |
-
# if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
| 62 |
-
- name: Test with pytest
|
| 63 |
-
# if new tasks are added, run tests on them
|
| 64 |
-
if: steps.changed-tasks.outputs.tasks_any_modified == 'true'
|
| 65 |
-
run: python -m pytest tests/test_tasks.py -s -vv
|
| 66 |
-
# if api is modified, run tests on it
|
| 67 |
-
- name: Test more tasks with pytest
|
| 68 |
-
env:
|
| 69 |
-
API: true
|
| 70 |
-
if: steps.changed-tasks.outputs.api_any_modified == 'true'
|
| 71 |
-
run: python -m pytest tests/test_tasks.py -s -vv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/.github/workflows/publish.yml
DELETED
|
@@ -1,97 +0,0 @@
|
|
| 1 |
-
name: Publish Python distribution to PyPI
|
| 2 |
-
|
| 3 |
-
on:
|
| 4 |
-
push:
|
| 5 |
-
tags:
|
| 6 |
-
- '*'
|
| 7 |
-
|
| 8 |
-
jobs:
|
| 9 |
-
build:
|
| 10 |
-
name: Build distribution
|
| 11 |
-
runs-on: ubuntu-latest
|
| 12 |
-
|
| 13 |
-
steps:
|
| 14 |
-
- uses: actions/checkout@v4
|
| 15 |
-
- name: Set up Python
|
| 16 |
-
uses: actions/setup-python@v5
|
| 17 |
-
with:
|
| 18 |
-
python-version: "3.x"
|
| 19 |
-
|
| 20 |
-
- name: Check version consistency
|
| 21 |
-
run: |
|
| 22 |
-
# Extract version from pyproject.toml
|
| 23 |
-
PYPROJECT_VERSION=$(grep 'version = ' pyproject.toml | head -1 | cut -d'"' -f2)
|
| 24 |
-
|
| 25 |
-
# Extract version from __init__.py
|
| 26 |
-
INIT_VERSION=$(grep '__version__ = ' lm_eval/__init__.py | head -1 | cut -d'"' -f2)
|
| 27 |
-
|
| 28 |
-
echo "Version in pyproject.toml: $PYPROJECT_VERSION"
|
| 29 |
-
echo "Version in __init__.py: $INIT_VERSION"
|
| 30 |
-
|
| 31 |
-
# Check if versions match
|
| 32 |
-
if [ "$PYPROJECT_VERSION" != "$INIT_VERSION" ]; then
|
| 33 |
-
echo "Error: Version mismatch between pyproject.toml ($PYPROJECT_VERSION) and __init__.py ($INIT_VERSION)"
|
| 34 |
-
exit 1
|
| 35 |
-
fi
|
| 36 |
-
|
| 37 |
-
echo "Version check passed: $PYPROJECT_VERSION"
|
| 38 |
-
|
| 39 |
-
- name: Install pypa/build
|
| 40 |
-
run: >-
|
| 41 |
-
python3 -m
|
| 42 |
-
pip install
|
| 43 |
-
build
|
| 44 |
-
--user
|
| 45 |
-
- name: Build a binary wheel and a source tarball
|
| 46 |
-
run: python3 -m build
|
| 47 |
-
- name: Store the distribution packages
|
| 48 |
-
uses: actions/upload-artifact@v4
|
| 49 |
-
with:
|
| 50 |
-
name: python-package-distributions
|
| 51 |
-
path: dist/
|
| 52 |
-
|
| 53 |
-
publish-to-pypi:
|
| 54 |
-
name: >-
|
| 55 |
-
Publish Python distribution to PyPI
|
| 56 |
-
if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes
|
| 57 |
-
needs:
|
| 58 |
-
- build
|
| 59 |
-
runs-on: ubuntu-latest
|
| 60 |
-
environment:
|
| 61 |
-
name: pypi
|
| 62 |
-
url: https://pypi.org/p/lm_eval
|
| 63 |
-
permissions:
|
| 64 |
-
id-token: write # IMPORTANT: mandatory for trusted publishing
|
| 65 |
-
|
| 66 |
-
steps:
|
| 67 |
-
- name: Download all the dists
|
| 68 |
-
uses: actions/download-artifact@v4
|
| 69 |
-
with:
|
| 70 |
-
name: python-package-distributions
|
| 71 |
-
path: dist/
|
| 72 |
-
- name: Publish distribution to PyPI
|
| 73 |
-
uses: pypa/gh-action-pypi-publish@release/v1
|
| 74 |
-
|
| 75 |
-
publish-to-testpypi:
|
| 76 |
-
name: Publish Python distribution to TestPyPI
|
| 77 |
-
needs:
|
| 78 |
-
- build
|
| 79 |
-
runs-on: ubuntu-latest
|
| 80 |
-
|
| 81 |
-
environment:
|
| 82 |
-
name: testpypi
|
| 83 |
-
url: https://test.pypi.org/p/lm_eval
|
| 84 |
-
|
| 85 |
-
permissions:
|
| 86 |
-
id-token: write # IMPORTANT: mandatory for trusted publishing
|
| 87 |
-
|
| 88 |
-
steps:
|
| 89 |
-
- name: Download all the dists
|
| 90 |
-
uses: actions/download-artifact@v4
|
| 91 |
-
with:
|
| 92 |
-
name: python-package-distributions
|
| 93 |
-
path: dist/
|
| 94 |
-
- name: Publish distribution to TestPyPI
|
| 95 |
-
uses: pypa/gh-action-pypi-publish@release/v1
|
| 96 |
-
with:
|
| 97 |
-
repository-url: https://test.pypi.org/legacy/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/.github/workflows/unit_tests.yml
DELETED
|
@@ -1,114 +0,0 @@
|
|
| 1 |
-
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
| 2 |
-
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
| 3 |
-
# just comment out unwanted steps to turn off the test.
|
| 4 |
-
name: Unit Tests
|
| 5 |
-
|
| 6 |
-
on:
|
| 7 |
-
push:
|
| 8 |
-
branches:
|
| 9 |
-
- 'main'
|
| 10 |
-
pull_request:
|
| 11 |
-
branches:
|
| 12 |
-
- 'main'
|
| 13 |
-
workflow_dispatch:
|
| 14 |
-
# Jobs run concurrently and steps run sequentially within a job.
|
| 15 |
-
# jobs: linter and cpu_tests. Add more jobs/steps as required.
|
| 16 |
-
jobs:
|
| 17 |
-
linter:
|
| 18 |
-
name: Linters
|
| 19 |
-
runs-on: ubuntu-latest
|
| 20 |
-
timeout-minutes: 5
|
| 21 |
-
|
| 22 |
-
steps:
|
| 23 |
-
- name: Checkout Code
|
| 24 |
-
uses: actions/checkout@v4
|
| 25 |
-
- name: Set up Python 3.9
|
| 26 |
-
uses: actions/setup-python@v5
|
| 27 |
-
with:
|
| 28 |
-
python-version: 3.9
|
| 29 |
-
cache: pip
|
| 30 |
-
cache-dependency-path: pyproject.toml
|
| 31 |
-
- name: Pre-Commit
|
| 32 |
-
env:
|
| 33 |
-
SKIP: "no-commit-to-branch,mypy"
|
| 34 |
-
uses: pre-commit/action@v3.0.1
|
| 35 |
-
# Job 2
|
| 36 |
-
testcpu:
|
| 37 |
-
name: CPU Tests
|
| 38 |
-
runs-on: ubuntu-latest
|
| 39 |
-
strategy:
|
| 40 |
-
fail-fast: true
|
| 41 |
-
matrix:
|
| 42 |
-
python-version: ["3.9", "3.10", "3.11"]
|
| 43 |
-
timeout-minutes: 30
|
| 44 |
-
steps:
|
| 45 |
-
- name: Checkout Code
|
| 46 |
-
uses: actions/checkout@v4
|
| 47 |
-
- name: Set up Python ${{ matrix.python-version }}
|
| 48 |
-
uses: actions/setup-python@v5
|
| 49 |
-
with:
|
| 50 |
-
python-version: ${{ matrix.python-version }}
|
| 51 |
-
cache: pip
|
| 52 |
-
cache-dependency-path: pyproject.toml
|
| 53 |
-
|
| 54 |
-
# Cache HuggingFace cache directory for CPU tests
|
| 55 |
-
- name: Cache HuggingFace cache (CPU tests)
|
| 56 |
-
uses: actions/cache@v3
|
| 57 |
-
id: cache-hf-cpu
|
| 58 |
-
with:
|
| 59 |
-
path: ~/.cache/huggingface
|
| 60 |
-
key: ${{ runner.os }}-hf-cache-cpu
|
| 61 |
-
restore-keys: |
|
| 62 |
-
${{ runner.os }}-hf-cache-cpu
|
| 63 |
-
|
| 64 |
-
- name: Install dependencies
|
| 65 |
-
run: |
|
| 66 |
-
python -m pip install --upgrade pip
|
| 67 |
-
pip install -e '.[dev]' --extra-index-url https://download.pytorch.org/whl/cpu
|
| 68 |
-
pip install hf_xet
|
| 69 |
-
|
| 70 |
-
- name: Test with pytest
|
| 71 |
-
run: python -m pytest --showlocals -s -vv -n=auto --ignore=tests/models/test_neuralmagic.py --ignore=tests/models/test_openvino.py --ignore=tests/models/test_hf_steered.py
|
| 72 |
-
continue-on-error: true # Continue workflow even if tests fail
|
| 73 |
-
|
| 74 |
-
# Save test artifacts
|
| 75 |
-
- name: Archive test artifacts
|
| 76 |
-
uses: actions/upload-artifact@v4
|
| 77 |
-
with:
|
| 78 |
-
name: output_testcpu${{ matrix.python-version }}
|
| 79 |
-
path: |
|
| 80 |
-
test_logs/*
|
| 81 |
-
|
| 82 |
-
# testmodels:
|
| 83 |
-
# name: External LM Tests
|
| 84 |
-
# runs-on: ubuntu-latest
|
| 85 |
-
# timeout-minutes: 30
|
| 86 |
-
# steps:
|
| 87 |
-
# - name: Checkout Code
|
| 88 |
-
# uses: actions/checkout@v4
|
| 89 |
-
# - name: Set up Python 3.9
|
| 90 |
-
# uses: actions/setup-python@v5
|
| 91 |
-
# with:
|
| 92 |
-
# python-version: 3.9
|
| 93 |
-
# cache: pip
|
| 94 |
-
# cache-dependency-path: pyproject.toml
|
| 95 |
-
#
|
| 96 |
-
# # Cache HuggingFace cache directory for External LM tests
|
| 97 |
-
# - name: Cache HuggingFace cache (External LM tests)
|
| 98 |
-
# uses: actions/cache@v3
|
| 99 |
-
# id: cache-hf-lm
|
| 100 |
-
# with:
|
| 101 |
-
# path: ~/.cache/huggingface
|
| 102 |
-
# key: ${{ runner.os }}-hf-cache-external-lm
|
| 103 |
-
# restore-keys: |
|
| 104 |
-
# ${{ runner.os }}-hf-cache-external-lm
|
| 105 |
-
#
|
| 106 |
-
# - name: Install dependencies
|
| 107 |
-
# run: |
|
| 108 |
-
# python -m pip install --upgrade pip
|
| 109 |
-
# pip install -e '.[dev,optimum,deepsparse,sparseml,api]' --extra-index-url https://download.pytorch.org/whl/cpu
|
| 110 |
-
# pip install -U transformers peft accelerate
|
| 111 |
-
#
|
| 112 |
-
# - name: Test with pytest
|
| 113 |
-
# run: python -m pytest tests/models --showlocals -s -vv
|
| 114 |
-
# continue-on-error: true # Continue workflow even if tests fail
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/.gitignore
DELETED
|
@@ -1,47 +0,0 @@
|
|
| 1 |
-
# macOS system files
|
| 2 |
-
.DS_Store
|
| 3 |
-
|
| 4 |
-
# Virtual environments
|
| 5 |
-
.venv/
|
| 6 |
-
venv/
|
| 7 |
-
ENV/
|
| 8 |
-
env/
|
| 9 |
-
*.env
|
| 10 |
-
|
| 11 |
-
# Python bytecode and build artifacts
|
| 12 |
-
__pycache__/
|
| 13 |
-
*.py[cod]
|
| 14 |
-
*.so
|
| 15 |
-
*.egg-info/
|
| 16 |
-
build/
|
| 17 |
-
dist/
|
| 18 |
-
|
| 19 |
-
# IDE & editor settings
|
| 20 |
-
.vscode/
|
| 21 |
-
.idea/
|
| 22 |
-
|
| 23 |
-
# Jupyter
|
| 24 |
-
.ipynb_checkpoints/
|
| 25 |
-
profile_default/
|
| 26 |
-
ipython_config.py
|
| 27 |
-
|
| 28 |
-
# Output and data
|
| 29 |
-
output/
|
| 30 |
-
data/
|
| 31 |
-
temp/
|
| 32 |
-
test_logs/
|
| 33 |
-
|
| 34 |
-
# Caching
|
| 35 |
-
lm_eval/caching/.cache
|
| 36 |
-
lm_cache/
|
| 37 |
-
|
| 38 |
-
# Logging
|
| 39 |
-
*.log
|
| 40 |
-
logs/
|
| 41 |
-
|
| 42 |
-
# wandb experiment tracking
|
| 43 |
-
wandb/
|
| 44 |
-
examples/wandb/
|
| 45 |
-
|
| 46 |
-
# PyInstaller
|
| 47 |
-
*.spec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/.pre-commit-config.yaml
DELETED
|
@@ -1,60 +0,0 @@
|
|
| 1 |
-
# Ignore test linting to avoid conflicting changes to version stability.
|
| 2 |
-
exclude: ^tests/testdata/
|
| 3 |
-
repos:
|
| 4 |
-
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 5 |
-
rev: v5.0.0
|
| 6 |
-
hooks:
|
| 7 |
-
- id: check-added-large-files
|
| 8 |
-
- id: check-ast
|
| 9 |
-
- id: check-byte-order-marker
|
| 10 |
-
- id: check-case-conflict
|
| 11 |
-
- id: check-json
|
| 12 |
-
- id: check-merge-conflict
|
| 13 |
-
args: [--assume-in-merge]
|
| 14 |
-
- id: check-symlinks
|
| 15 |
-
- id: check-yaml
|
| 16 |
-
args: ["--unsafe"]
|
| 17 |
-
- id: destroyed-symlinks
|
| 18 |
-
- id: detect-private-key
|
| 19 |
-
- id: end-of-file-fixer
|
| 20 |
-
- id: no-commit-to-branch
|
| 21 |
-
always_run: false
|
| 22 |
-
- id: requirements-txt-fixer
|
| 23 |
-
- id: trailing-whitespace
|
| 24 |
-
args: [--markdown-linebreak-ext=md]
|
| 25 |
-
- id: fix-byte-order-marker
|
| 26 |
-
exclude: docs/CNAME
|
| 27 |
-
- id: fix-encoding-pragma
|
| 28 |
-
args: [--remove]
|
| 29 |
-
- id: mixed-line-ending
|
| 30 |
-
args: [--fix=lf]
|
| 31 |
-
- repo: https://github.com/astral-sh/ruff-pre-commit
|
| 32 |
-
rev: v0.11.0
|
| 33 |
-
hooks:
|
| 34 |
-
# Run the linter.
|
| 35 |
-
- id: ruff
|
| 36 |
-
args:
|
| 37 |
-
- --fix
|
| 38 |
-
# Run the formatter.
|
| 39 |
-
- id: ruff-format
|
| 40 |
-
- repo: https://github.com/codespell-project/codespell
|
| 41 |
-
rev: v2.4.1
|
| 42 |
-
hooks:
|
| 43 |
-
- id: codespell
|
| 44 |
-
exclude: >
|
| 45 |
-
(?x)^(
|
| 46 |
-
.*\.json|ignore.txt|lm_eval/tasks/.*|.*yaml|.*\.ipynb
|
| 47 |
-
)$
|
| 48 |
-
args: [--check-filenames, --check-hidden, --ignore-words=ignore.txt]
|
| 49 |
-
- repo: https://github.com/jackdewinter/pymarkdown
|
| 50 |
-
rev: v0.9.29
|
| 51 |
-
hooks:
|
| 52 |
-
- id: pymarkdown
|
| 53 |
-
exclude: ^lm_eval/tasks/
|
| 54 |
-
args: [fix, -r]
|
| 55 |
-
# - repo: https://github.com/pre-commit/mirrors-mypy
|
| 56 |
-
# rev: v1.5.1
|
| 57 |
-
# hooks:
|
| 58 |
-
# - id: mypy
|
| 59 |
-
# additional_dependencies: [".[sentencepiece,multilingual,promptsource,gptq]", "types-PyYAML", "types-requests"]
|
| 60 |
-
# exclude: ^tests/.*$
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/CITATION.bib
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
@misc{eval-harness,
|
| 2 |
-
author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
|
| 3 |
-
title = {A framework for few-shot language model evaluation},
|
| 4 |
-
month = 12,
|
| 5 |
-
year = 2023,
|
| 6 |
-
publisher = {Zenodo},
|
| 7 |
-
version = {v0.4.0},
|
| 8 |
-
doi = {10.5281/zenodo.10256836},
|
| 9 |
-
url = {https://zenodo.org/records/10256836}
|
| 10 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/CODEOWNERS
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
* @baberabb @stellaathena
|
|
|
|
|
|
lm-evaluation-harness/LICENSE.md
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
MIT License
|
| 2 |
-
|
| 3 |
-
Copyright (c) 2020 EleutherAI
|
| 4 |
-
|
| 5 |
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
-
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
-
in the Software without restriction, including without limitation the rights
|
| 8 |
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
-
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
-
furnished to do so, subject to the following conditions:
|
| 11 |
-
|
| 12 |
-
The above copyright notice and this permission notice shall be included in all
|
| 13 |
-
copies or substantial portions of the Software.
|
| 14 |
-
|
| 15 |
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
-
SOFTWARE.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/MANIFEST.in
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
recursive-include tests
|
|
|
|
|
|
lm-evaluation-harness/README.md
DELETED
|
@@ -1,625 +0,0 @@
|
|
| 1 |
-
# Language Model Evaluation Harness
|
| 2 |
-
|
| 3 |
-
[](https://doi.org/10.5281/zenodo.10256836)
|
| 4 |
-
|
| 5 |
-
---
|
| 6 |
-
|
| 7 |
-
## Latest News 📣
|
| 8 |
-
|
| 9 |
-
- [2025/03] Added support for steering HF models!
|
| 10 |
-
- [2025/02] Added [SGLang](https://docs.sglang.ai/) support!
|
| 11 |
-
- [2024/09] We are prototyping allowing users of LM Evaluation Harness to create and evaluate on text+image multimodal input, text output tasks, and have just added the `hf-multimodal` and `vllm-vlm` model types and `mmmu` task as a prototype feature. We welcome users to try out this in-progress feature and stress-test it for themselves, and suggest they check out [`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval), a wonderful project originally forking off of the lm-evaluation-harness, for a broader range of multimodal tasks, models, and features.
|
| 12 |
-
- [2024/07] [API model](docs/API_guide.md) support has been updated and refactored, introducing support for batched and async requests, and making it significantly easier to customize and use for your own purposes. **To run Llama 405B, we recommend using VLLM's OpenAI-compliant API to host the model, and use the `local-completions` model type to evaluate the model.**
|
| 13 |
-
- [2024/07] New Open LLM Leaderboard tasks have been added ! You can find them under the [leaderboard](lm_eval/tasks/leaderboard/README.md) task group.
|
| 14 |
-
|
| 15 |
-
---
|
| 16 |
-
|
| 17 |
-
## Announcement
|
| 18 |
-
|
| 19 |
-
**A new v0.4.0 release of lm-evaluation-harness is available** !
|
| 20 |
-
|
| 21 |
-
New updates and features include:
|
| 22 |
-
|
| 23 |
-
- **New Open LLM Leaderboard tasks have been added ! You can find them under the [leaderboard](lm_eval/tasks/leaderboard/README.md) task group.**
|
| 24 |
-
- Internal refactoring
|
| 25 |
-
- Config-based task creation and configuration
|
| 26 |
-
- Easier import and sharing of externally-defined task config YAMLs
|
| 27 |
-
- Support for Jinja2 prompt design, easy modification of prompts + prompt imports from Promptsource
|
| 28 |
-
- More advanced configuration options, including output post-processing, answer extraction, and multiple LM generations per document, configurable fewshot settings, and more
|
| 29 |
-
- Speedups and new modeling libraries supported, including: faster data-parallel HF model usage, vLLM support, MPS support with HuggingFace, and more
|
| 30 |
-
- Logging and usability changes
|
| 31 |
-
- New tasks including CoT BIG-Bench-Hard, Belebele, user-defined task groupings, and more
|
| 32 |
-
|
| 33 |
-
Please see our updated documentation pages in `docs/` for more details.
|
| 34 |
-
|
| 35 |
-
Development will be continuing on the `main` branch, and we encourage you to give us feedback on what features are desired and how to improve the library further, or ask questions, either in issues or PRs on GitHub, or in the [EleutherAI discord](https://discord.gg/eleutherai)!
|
| 36 |
-
|
| 37 |
-
---
|
| 38 |
-
|
| 39 |
-
## Overview
|
| 40 |
-
|
| 41 |
-
This project provides a unified framework to test generative language models on a large number of different evaluation tasks.
|
| 42 |
-
|
| 43 |
-
**Features:**
|
| 44 |
-
|
| 45 |
-
- Over 60 standard academic benchmarks for LLMs, with hundreds of subtasks and variants implemented.
|
| 46 |
-
- Support for models loaded via [transformers](https://github.com/huggingface/transformers/) (including quantization via [GPTQModel](https://github.com/ModelCloud/GPTQModel) and [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ)), [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), and [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/), with a flexible tokenization-agnostic interface.
|
| 47 |
-
- Support for fast and memory-efficient inference with [vLLM](https://github.com/vllm-project/vllm).
|
| 48 |
-
- Support for commercial APIs including [OpenAI](https://openai.com), and [TextSynth](https://textsynth.com/).
|
| 49 |
-
- Support for evaluation on adapters (e.g. LoRA) supported in [HuggingFace's PEFT library](https://github.com/huggingface/peft).
|
| 50 |
-
- Support for local models and benchmarks.
|
| 51 |
-
- Evaluation with publicly available prompts ensures reproducibility and comparability between papers.
|
| 52 |
-
- Easy support for custom prompts and evaluation metrics.
|
| 53 |
-
|
| 54 |
-
The Language Model Evaluation Harness is the backend for 🤗 Hugging Face's popular [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard), has been used in [hundreds of papers](https://scholar.google.com/scholar?oi=bibs&hl=en&authuser=2&cites=15052937328817631261,4097184744846514103,1520777361382155671,17476825572045927382,18443729326628441434,14801318227356878622,7890865700763267262,12854182577605049984,15641002901115500560,5104500764547628290), and is used internally by dozens of organizations including NVIDIA, Cohere, BigScience, BigCode, Nous Research, and Mosaic ML.
|
| 55 |
-
|
| 56 |
-
## Install
|
| 57 |
-
|
| 58 |
-
To install the `lm-eval` package from the github repository, run:
|
| 59 |
-
|
| 60 |
-
```bash
|
| 61 |
-
git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness
|
| 62 |
-
cd lm-evaluation-harness
|
| 63 |
-
pip install -e .
|
| 64 |
-
```
|
| 65 |
-
|
| 66 |
-
We also provide a number of optional dependencies for extended functionality. A detailed table is available at the end of this document.
|
| 67 |
-
|
| 68 |
-
## Basic Usage
|
| 69 |
-
|
| 70 |
-
### User Guide
|
| 71 |
-
|
| 72 |
-
A user guide detailing the full list of supported arguments is provided [here](./docs/interface.md), and on the terminal by calling `lm_eval -h`. Alternatively, you can use `lm-eval` instead of `lm_eval`.
|
| 73 |
-
|
| 74 |
-
A list of supported tasks (or groupings of tasks) can be viewed with `lm-eval --tasks list`. Task descriptions and links to corresponding subfolders are provided [here](./lm_eval/tasks/README.md).
|
| 75 |
-
|
| 76 |
-
### Hugging Face `transformers`
|
| 77 |
-
|
| 78 |
-
To evaluate a model hosted on the [HuggingFace Hub](https://huggingface.co/models) (e.g. GPT-J-6B) on `hellaswag` you can use the following command (this assumes you are using a CUDA-compatible GPU):
|
| 79 |
-
|
| 80 |
-
```bash
|
| 81 |
-
lm_eval --model hf \
|
| 82 |
-
--model_args pretrained=EleutherAI/gpt-j-6B \
|
| 83 |
-
--tasks hellaswag \
|
| 84 |
-
--device cuda:0 \
|
| 85 |
-
--batch_size 8
|
| 86 |
-
```
|
| 87 |
-
|
| 88 |
-
Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partially trained checkpoints, or to specify the datatype for running a model:
|
| 89 |
-
|
| 90 |
-
```bash
|
| 91 |
-
lm_eval --model hf \
|
| 92 |
-
--model_args pretrained=EleutherAI/pythia-160m,revision=step100000,dtype="float" \
|
| 93 |
-
--tasks lambada_openai,hellaswag \
|
| 94 |
-
--device cuda:0 \
|
| 95 |
-
--batch_size 8
|
| 96 |
-
```
|
| 97 |
-
|
| 98 |
-
Models that are loaded via both `transformers.AutoModelForCausalLM` (autoregressive, decoder-only GPT style models) and `transformers.AutoModelForSeq2SeqLM` (such as encoder-decoder models like T5) in Huggingface are supported.
|
| 99 |
-
|
| 100 |
-
Batch size selection can be automated by setting the ```--batch_size``` flag to ```auto```. This will perform automatic detection of the largest batch size that will fit on your device. On tasks where there is a large difference between the longest and shortest example, it can be helpful to periodically recompute the largest batch size, to gain a further speedup. To do this, append ```:N``` to above flag to automatically recompute the largest batch size ```N``` times. For example, to recompute the batch size 4 times, the command would be:
|
| 101 |
-
|
| 102 |
-
```bash
|
| 103 |
-
lm_eval --model hf \
|
| 104 |
-
--model_args pretrained=EleutherAI/pythia-160m,revision=step100000,dtype="float" \
|
| 105 |
-
--tasks lambada_openai,hellaswag \
|
| 106 |
-
--device cuda:0 \
|
| 107 |
-
--batch_size auto:4
|
| 108 |
-
```
|
| 109 |
-
|
| 110 |
-
> [!Note]
|
| 111 |
-
> Just like you can provide a local path to `transformers.AutoModel`, you can also provide a local path to `lm_eval` via `--model_args pretrained=/path/to/model`
|
| 112 |
-
|
| 113 |
-
#### Multi-GPU Evaluation with Hugging Face `accelerate`
|
| 114 |
-
|
| 115 |
-
We support three main ways of using Hugging Face's [accelerate 🚀](https://github.com/huggingface/accelerate) library for multi-GPU evaluation.
|
| 116 |
-
|
| 117 |
-
To perform *data-parallel evaluation* (where each GPU loads a **separate full copy** of the model), we leverage the `accelerate` launcher as follows:
|
| 118 |
-
|
| 119 |
-
```bash
|
| 120 |
-
accelerate launch -m lm_eval --model hf \
|
| 121 |
-
--tasks lambada_openai,arc_easy \
|
| 122 |
-
--batch_size 16
|
| 123 |
-
```
|
| 124 |
-
|
| 125 |
-
(or via `accelerate launch --no-python lm_eval`).
|
| 126 |
-
|
| 127 |
-
For cases where your model can fit on a single GPU, this allows you to evaluate on K GPUs K times faster than on one.
|
| 128 |
-
|
| 129 |
-
**WARNING**: This setup does not work with FSDP model sharding, so in `accelerate config` FSDP must be disabled, or the NO_SHARD FSDP option must be used.
|
| 130 |
-
|
| 131 |
-
The second way of using `accelerate` for multi-GPU evaluation is when your model is *too large to fit on a single GPU.*
|
| 132 |
-
|
| 133 |
-
In this setting, run the library *outside the `accelerate` launcher*, but passing `parallelize=True` to `--model_args` as follows:
|
| 134 |
-
|
| 135 |
-
```bash
|
| 136 |
-
lm_eval --model hf \
|
| 137 |
-
--tasks lambada_openai,arc_easy \
|
| 138 |
-
--model_args parallelize=True \
|
| 139 |
-
--batch_size 16
|
| 140 |
-
```
|
| 141 |
-
|
| 142 |
-
This means that your model's weights will be split across all available GPUs.
|
| 143 |
-
|
| 144 |
-
For more advanced users or even larger models, we allow for the following arguments when `parallelize=True` as well:
|
| 145 |
-
|
| 146 |
-
- `device_map_option`: How to split model weights across available GPUs. defaults to "auto".
|
| 147 |
-
- `max_memory_per_gpu`: the max GPU memory to use per GPU in loading the model.
|
| 148 |
-
- `max_cpu_memory`: the max amount of CPU memory to use when offloading the model weights to RAM.
|
| 149 |
-
- `offload_folder`: a folder where model weights will be offloaded to disk if needed.
|
| 150 |
-
|
| 151 |
-
The third option is to use both at the same time. This will allow you to take advantage of both data parallelism and model sharding, and is especially useful for models that are too large to fit on a single GPU.
|
| 152 |
-
|
| 153 |
-
```bash
|
| 154 |
-
accelerate launch --multi_gpu --num_processes {nb_of_copies_of_your_model} \
|
| 155 |
-
-m lm_eval --model hf \
|
| 156 |
-
--tasks lambada_openai,arc_easy \
|
| 157 |
-
--model_args parallelize=True \
|
| 158 |
-
--batch_size 16
|
| 159 |
-
```
|
| 160 |
-
|
| 161 |
-
To learn more about model parallelism and how to use it with the `accelerate` library, see the [accelerate documentation](https://huggingface.co/docs/transformers/v4.15.0/en/parallelism)
|
| 162 |
-
|
| 163 |
-
**Warning: We do not natively support multi-node evaluation using the `hf` model type! Please reference [our GPT-NeoX library integration](https://github.com/EleutherAI/gpt-neox/blob/main/eval.py) for an example of code in which a custom multi-machine evaluation script is written.**
|
| 164 |
-
|
| 165 |
-
**Note: we do not currently support multi-node evaluations natively, and advise using either an externally hosted server to run inference requests against, or creating a custom integration with your distributed framework [as is done for the GPT-NeoX library](https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py).**
|
| 166 |
-
|
| 167 |
-
### Steered Hugging Face `transformers` models
|
| 168 |
-
|
| 169 |
-
To evaluate a Hugging Face `transformers` model with steering vectors applied, specify the model type as `steered` and provide the path to either a PyTorch file containing pre-defined steering vectors, or a CSV file that specifies how to derive steering vectors from pretrained `sparsify` or `sae_lens` models (you will need to install the corresponding optional dependency for this method).
|
| 170 |
-
|
| 171 |
-
Specify pre-defined steering vectors:
|
| 172 |
-
|
| 173 |
-
```python
|
| 174 |
-
import torch
|
| 175 |
-
|
| 176 |
-
steer_config = {
|
| 177 |
-
"layers.3": {
|
| 178 |
-
"steering_vector": torch.randn(1, 768),
|
| 179 |
-
"bias": torch.randn(1, 768),
|
| 180 |
-
"steering_coefficient": 1,
|
| 181 |
-
"action": "add"
|
| 182 |
-
},
|
| 183 |
-
}
|
| 184 |
-
torch.save(steer_config, "steer_config.pt")
|
| 185 |
-
```
|
| 186 |
-
|
| 187 |
-
Specify derived steering vectors:
|
| 188 |
-
|
| 189 |
-
```python
|
| 190 |
-
import pandas as pd
|
| 191 |
-
|
| 192 |
-
pd.DataFrame({
|
| 193 |
-
"loader": ["sparsify"],
|
| 194 |
-
"action": ["add"],
|
| 195 |
-
"sparse_model": ["EleutherAI/sae-pythia-70m-32k"],
|
| 196 |
-
"hookpoint": ["layers.3"],
|
| 197 |
-
"feature_index": [30],
|
| 198 |
-
"steering_coefficient": [10.0],
|
| 199 |
-
}).to_csv("steer_config.csv", index=False)
|
| 200 |
-
```
|
| 201 |
-
|
| 202 |
-
Run the evaluation harness with steering vectors applied:
|
| 203 |
-
|
| 204 |
-
```bash
|
| 205 |
-
lm_eval --model steered \
|
| 206 |
-
--model_args pretrained=EleutherAI/pythia-160m,steer_path=steer_config.pt \
|
| 207 |
-
--tasks lambada_openai,hellaswag \
|
| 208 |
-
--device cuda:0 \
|
| 209 |
-
--batch_size 8
|
| 210 |
-
```
|
| 211 |
-
|
| 212 |
-
### NVIDIA `nemo` models
|
| 213 |
-
|
| 214 |
-
[NVIDIA NeMo Framework](https://github.com/NVIDIA/NeMo) is a generative AI framework built for researchers and pytorch developers working on language models.
|
| 215 |
-
|
| 216 |
-
To evaluate a `nemo` model, start by installing NeMo following [the documentation](https://github.com/NVIDIA/NeMo?tab=readme-ov-file#installation). We highly recommended to use the NVIDIA PyTorch or NeMo container, especially if having issues installing Apex or any other dependencies (see [latest released containers](https://github.com/NVIDIA/NeMo/releases)). Please also install the lm evaluation harness library following the instructions in [the Install section](https://github.com/EleutherAI/lm-evaluation-harness/tree/main?tab=readme-ov-file#install).
|
| 217 |
-
|
| 218 |
-
NeMo models can be obtained through [NVIDIA NGC Catalog](https://catalog.ngc.nvidia.com/models) or in [NVIDIA's Hugging Face page](https://huggingface.co/nvidia). In [NVIDIA NeMo Framework](https://github.com/NVIDIA/NeMo/tree/main/scripts/nlp_language_modeling) there are conversion scripts to convert the `hf` checkpoints of popular models like llama, falcon, mixtral or mpt to `nemo`.
|
| 219 |
-
|
| 220 |
-
Run a `nemo` model on one GPU:
|
| 221 |
-
|
| 222 |
-
```bash
|
| 223 |
-
lm_eval --model nemo_lm \
|
| 224 |
-
--model_args path=<path_to_nemo_model> \
|
| 225 |
-
--tasks hellaswag \
|
| 226 |
-
--batch_size 32
|
| 227 |
-
```
|
| 228 |
-
|
| 229 |
-
It is recommended to unpack the `nemo` model to avoid the unpacking inside the docker container - it may overflow disk space. For that you can run:
|
| 230 |
-
|
| 231 |
-
```bash
|
| 232 |
-
mkdir MY_MODEL
|
| 233 |
-
tar -xvf MY_MODEL.nemo -c MY_MODEL
|
| 234 |
-
```
|
| 235 |
-
|
| 236 |
-
#### Multi-GPU evaluation with NVIDIA `nemo` models
|
| 237 |
-
|
| 238 |
-
By default, only one GPU is used. But we do support either data replication or tensor/pipeline parallelism during evaluation, on one node.
|
| 239 |
-
|
| 240 |
-
1) To enable data replication, set the `model_args` of `devices` to the number of data replicas to run. For example, the command to run 8 data replicas over 8 GPUs is:
|
| 241 |
-
|
| 242 |
-
```bash
|
| 243 |
-
torchrun --nproc-per-node=8 --no-python lm_eval \
|
| 244 |
-
--model nemo_lm \
|
| 245 |
-
--model_args path=<path_to_nemo_model>,devices=8 \
|
| 246 |
-
--tasks hellaswag \
|
| 247 |
-
--batch_size 32
|
| 248 |
-
```
|
| 249 |
-
|
| 250 |
-
1) To enable tensor and/or pipeline parallelism, set the `model_args` of `tensor_model_parallel_size` and/or `pipeline_model_parallel_size`. In addition, you also have to set up `devices` to be equal to the product of `tensor_model_parallel_size` and/or `pipeline_model_parallel_size`. For example, the command to use one node of 4 GPUs with tensor parallelism of 2 and pipeline parallelism of 2 is:
|
| 251 |
-
|
| 252 |
-
```bash
|
| 253 |
-
torchrun --nproc-per-node=4 --no-python lm_eval \
|
| 254 |
-
--model nemo_lm \
|
| 255 |
-
--model_args path=<path_to_nemo_model>,devices=4,tensor_model_parallel_size=2,pipeline_model_parallel_size=2 \
|
| 256 |
-
--tasks hellaswag \
|
| 257 |
-
--batch_size 32
|
| 258 |
-
```
|
| 259 |
-
|
| 260 |
-
Note that it is recommended to substitute the `python` command by `torchrun --nproc-per-node=<number of devices> --no-python` to facilitate loading the model into the GPUs. This is especially important for large checkpoints loaded into multiple GPUs.
|
| 261 |
-
|
| 262 |
-
Not supported yet: multi-node evaluation and combinations of data replication with tensor or pipeline parallelism.
|
| 263 |
-
|
| 264 |
-
#### Multi-GPU evaluation with OpenVINO models
|
| 265 |
-
|
| 266 |
-
Pipeline parallelism during evaluation is supported with OpenVINO models
|
| 267 |
-
|
| 268 |
-
To enable pipeline parallelism, set the `model_args` of `pipeline_parallel`. In addition, you also have to set up `device` to value `HETERO:<GPU index1>,<GPU index2>` for example `HETERO:GPU.1,GPU.0` For example, the command to use pipeline parallelism of 2 is:
|
| 269 |
-
|
| 270 |
-
```bash
|
| 271 |
-
lm_eval --model openvino \
|
| 272 |
-
--tasks wikitext \
|
| 273 |
-
--model_args pretrained=<path_to_ov_model>,pipeline_parallel=True \
|
| 274 |
-
--device HETERO:GPU.1,GPU.0
|
| 275 |
-
```
|
| 276 |
-
|
| 277 |
-
### Tensor + Data Parallel and Optimized Inference with `vLLM`
|
| 278 |
-
|
| 279 |
-
We also support vLLM for faster inference on [supported model types](https://docs.vllm.ai/en/latest/models/supported_models.html), especially faster when splitting a model across multiple GPUs. For single-GPU or multi-GPU — tensor parallel, data parallel, or a combination of both — inference, for example:
|
| 280 |
-
|
| 281 |
-
```bash
|
| 282 |
-
lm_eval --model vllm \
|
| 283 |
-
--model_args pretrained={model_name},tensor_parallel_size={GPUs_per_model},dtype=auto,gpu_memory_utilization=0.8,data_parallel_size={model_replicas} \
|
| 284 |
-
--tasks lambada_openai \
|
| 285 |
-
--batch_size auto
|
| 286 |
-
```
|
| 287 |
-
|
| 288 |
-
To use vllm, do `pip install lm_eval[vllm]`. For a full list of supported vLLM configurations, please reference our [vLLM integration](https://github.com/EleutherAI/lm-evaluation-harness/blob/e74ec966556253fbe3d8ecba9de675c77c075bce/lm_eval/models/vllm_causallms.py) and the vLLM documentation.
|
| 289 |
-
|
| 290 |
-
vLLM occasionally differs in output from Huggingface. We treat Huggingface as the reference implementation, and provide a [script](./scripts/model_comparator.py) for checking the validity of vllm results against HF.
|
| 291 |
-
|
| 292 |
-
> [!Tip]
|
| 293 |
-
> For fastest performance, we recommend using `--batch_size auto` for vLLM whenever possible, to leverage its continuous batching functionality!
|
| 294 |
-
|
| 295 |
-
> [!Tip]
|
| 296 |
-
> Passing `max_model_len=4096` or some other reasonable default to vLLM through model args may cause speedups or prevent out-of-memory errors when trying to use auto batch size, such as for Mistral-7B-v0.1 which defaults to a maximum length of 32k.
|
| 297 |
-
|
| 298 |
-
### Tensor + Data Parallel and Fast Offline Batching Inference with `SGLang`
|
| 299 |
-
|
| 300 |
-
We support SGLang for efficient offline batch inference. Its **[Fast Backend Runtime](https://docs.sglang.ai/index.html)** delivers high performance through optimized memory management and parallel processing techniques. Key features include tensor parallelism, continuous batching, and support for various quantization methods (FP8/INT4/AWQ/GPTQ).
|
| 301 |
-
|
| 302 |
-
To use SGLang as the evaluation backend, please **install it in advance** via SGLang documents [here](https://docs.sglang.ai/start/install.html#install-sglang).
|
| 303 |
-
|
| 304 |
-
> [!Tip]
|
| 305 |
-
> Due to the installing method of [`Flashinfer`](https://docs.flashinfer.ai/)-- a fast attention kernel library, we don't include the dependencies of `SGLang` within [pyproject.toml](pyproject.toml). Note that the `Flashinfer` also has some requirements on `torch` version.
|
| 306 |
-
|
| 307 |
-
SGLang's server arguments are slightly different from other backends, see [here](https://docs.sglang.ai/backend/server_arguments.html) for more information. We provide an example of the usage here:
|
| 308 |
-
|
| 309 |
-
```bash
|
| 310 |
-
lm_eval --model sglang \
|
| 311 |
-
--model_args pretrained={model_name},dp_size={data_parallel_size},tp_size={tensor_parallel_size},dtype=auto \
|
| 312 |
-
--tasks gsm8k_cot \
|
| 313 |
-
--batch_size auto
|
| 314 |
-
```
|
| 315 |
-
|
| 316 |
-
> [!Tip]
|
| 317 |
-
> When encountering out of memory (OOM) errors (especially for multiple-choice tasks), try these solutions:
|
| 318 |
-
>
|
| 319 |
-
> 1. Use a manual `batch_size`, rather than `auto`.
|
| 320 |
-
> 2. Lower KV cache pool memory usage by adjusting `mem_fraction_static` - Add to your model arguments for example `--model_args pretrained=...,mem_fraction_static=0.7`.
|
| 321 |
-
> 3. Increase tensor parallel size `tp_size` (if using multiple GPUs).
|
| 322 |
-
|
| 323 |
-
### Model APIs and Inference Servers
|
| 324 |
-
|
| 325 |
-
Our library also supports the evaluation of models served via several commercial APIs, and we hope to implement support for the most commonly used performant local/self-hosted inference servers.
|
| 326 |
-
|
| 327 |
-
To call a hosted model, use:
|
| 328 |
-
|
| 329 |
-
```bash
|
| 330 |
-
export OPENAI_API_KEY=YOUR_KEY_HERE
|
| 331 |
-
lm_eval --model openai-completions \
|
| 332 |
-
--model_args model=davinci-002 \
|
| 333 |
-
--tasks lambada_openai,hellaswag
|
| 334 |
-
```
|
| 335 |
-
|
| 336 |
-
We also support using your own local inference server with servers that mirror the OpenAI Completions and ChatCompletions APIs.
|
| 337 |
-
|
| 338 |
-
```bash
|
| 339 |
-
lm_eval --model local-completions --tasks gsm8k --model_args model=facebook/opt-125m,base_url=http://{yourip}:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False,batch_size=16
|
| 340 |
-
```
|
| 341 |
-
|
| 342 |
-
Note that for externally hosted models, configs such as `--device` which relate to where to place a local model should not be used and do not function. Just like you can use `--model_args` to pass arbitrary arguments to the model constructor for local models, you can use it to pass arbitrary arguments to the model API for hosted models. See the documentation of the hosting service for information on what arguments they support.
|
| 343 |
-
|
| 344 |
-
| API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: |
|
| 345 |
-
| --------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------|-----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------|
|
| 346 |
-
| OpenAI Completions | :heavy_check_mark: | `openai-completions`, `local-completions` | All OpenAI Completions API models | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 347 |
-
| OpenAI ChatCompletions | :heavy_check_mark: | `openai-chat-completions`, `local-chat-completions` | [All ChatCompletions API models](https://platform.openai.com/docs/guides/gpt) | `generate_until` (no logprobs) |
|
| 348 |
-
| Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `generate_until` (no logprobs) |
|
| 349 |
-
| Anthropic Chat | :heavy_check_mark: | `anthropic-chat`, `anthropic-chat-completions` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/docs/models-overview) | `generate_until` (no logprobs) |
|
| 350 |
-
| Textsynth | :heavy_check_mark: | `textsynth` | [All supported engines](https://textsynth.com/documentation.html#engines) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 351 |
-
| Cohere | [:hourglass: - blocked on Cohere API bug](https://github.com/EleutherAI/lm-evaluation-harness/pull/395) | N/A | [All `cohere.generate()` engines](https://docs.cohere.com/docs/models) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 352 |
-
| [Llama.cpp](https://github.com/ggerganov/llama.cpp) (via [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)) | :heavy_check_mark: | `gguf`, `ggml` | [All models supported by llama.cpp](https://github.com/ggerganov/llama.cpp) | `generate_until`, `loglikelihood`, (perplexity evaluation not yet implemented) |
|
| 353 |
-
| vLLM | :heavy_check_mark: | `vllm` | [Most HF Causal Language Models](https://docs.vllm.ai/en/latest/models/supported_models.html) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 354 |
-
| Mamba | :heavy_check_mark: | `mamba_ssm` | [Mamba architecture Language Models via the `mamba_ssm` package](https://huggingface.co/state-spaces) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 355 |
-
| Huggingface Optimum (Causal LMs) | :heavy_check_mark: | `openvino` | Any decoder-only AutoModelForCausalLM converted with Huggingface Optimum into OpenVINO™ Intermediate Representation (IR) format | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 356 |
-
| Huggingface Optimum-intel IPEX (Causal LMs) | :heavy_check_mark: | `ipex` | Any decoder-only AutoModelForCausalLM | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 357 |
-
| Neuron via AWS Inf2 (Causal LMs) | :heavy_check_mark: | `neuronx` | Any decoder-only AutoModelForCausalLM supported to run on [huggingface-ami image for inferentia2](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 358 |
-
| [Neural Magic DeepSparse](https://github.com/neuralmagic/deepsparse) | :heavy_check_mark: | `deepsparse` | Any LM from [SparseZoo](https://sparsezoo.neuralmagic.com/) or on [HF Hub with the "deepsparse" tag](https://huggingface.co/models?other=deepsparse) | `generate_until`, `loglikelihood` |
|
| 359 |
-
| [Neural Magic SparseML](https://github.com/neuralmagic/sparseml) | :heavy_check_mark: | `sparseml` | Any decoder-only AutoModelForCausalLM from [SparseZoo](https://sparsezoo.neuralmagic.com/) or on [HF Hub](https://huggingface.co/neuralmagic). Especially useful for models with quantization like [`zoo:llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized`](https://sparsezoo.neuralmagic.com/models/llama2-7b-gsm8k_llama2_pretrain-pruned60_quantized) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 360 |
-
| NVIDIA NeMo | :heavy_check_mark: | `nemo_lm` | [All supported models](https://docs.nvidia.com/nemo-framework/user-guide/24.09/nemotoolkit/core/core.html#nemo-models) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 361 |
-
| Watsonx.ai | :heavy_check_mark: | `watsonx_llm` | [Supported Watsonx.ai Engines](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx) | `generate_until` `loglikelihood` |
|
| 362 |
-
| [Your local inference server!](docs/API_guide.md) | :heavy_check_mark: | `local-completions` or `local-chat-completions` | Support for OpenAI API-compatible servers, with easy customization for other APIs. | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
|
| 363 |
-
|
| 364 |
-
Models which do not supply logits or logprobs can be used with tasks of type `generate_until` only, while local models, or APIs that supply logprobs/logits of their prompts, can be run on all task types: `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
|
| 365 |
-
|
| 366 |
-
For more information on the different task `output_types` and model request types, see [our documentation](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/model_guide.md#interface).
|
| 367 |
-
|
| 368 |
-
> [!Note]
|
| 369 |
-
> For best performance with closed chat model APIs such as Anthropic Claude 3 and GPT-4, we recommend carefully looking at a few sample outputs using `--limit 10` first to confirm answer extraction and scoring on generative tasks is performing as expected. providing `system="<some system prompt here>"` within `--model_args` for anthropic-chat-completions, to instruct the model what format to respond in, may be useful.
|
| 370 |
-
|
| 371 |
-
### Other Frameworks
|
| 372 |
-
|
| 373 |
-
A number of other libraries contain scripts for calling the eval harness through their library. These include [GPT-NeoX](https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py), [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/blob/main/examples/MoE/readme_evalharness.md), and [mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/eval_harness.py).
|
| 374 |
-
|
| 375 |
-
To create your own custom integration you can follow instructions from [this tutorial](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage).
|
| 376 |
-
|
| 377 |
-
### Additional Features
|
| 378 |
-
|
| 379 |
-
> [!Note]
|
| 380 |
-
> For tasks unsuitable for direct evaluation — either due risks associated with executing untrusted code or complexities in the evaluation process — the `--predict_only` flag is available to obtain decoded generations for post-hoc evaluation.
|
| 381 |
-
|
| 382 |
-
If you have a Metal compatible Mac, you can run the eval harness using the MPS back-end by replacing `--device cuda:0` with `--device mps` (requires PyTorch version 2.1 or higher). **Note that the PyTorch MPS backend is still in early stages of development, so correctness issues or unsupported operations may exist. If you observe oddities in model performance on the MPS back-end, we recommend first checking that a forward pass of your model on `--device cpu` and `--device mps` match.**
|
| 383 |
-
|
| 384 |
-
> [!Note]
|
| 385 |
-
> You can inspect what the LM inputs look like by running the following command:
|
| 386 |
-
>
|
| 387 |
-
> ```bash
|
| 388 |
-
> python write_out.py \
|
| 389 |
-
> --tasks <task1,task2,...> \
|
| 390 |
-
> --num_fewshot 5 \
|
| 391 |
-
> --num_examples 10 \
|
| 392 |
-
> --output_base_path /path/to/output/folder
|
| 393 |
-
> ```
|
| 394 |
-
>
|
| 395 |
-
> This will write out one text file for each task.
|
| 396 |
-
|
| 397 |
-
To verify the data integrity of the tasks you're performing in addition to running the tasks themselves, you can use the `--check_integrity` flag:
|
| 398 |
-
|
| 399 |
-
```bash
|
| 400 |
-
lm_eval --model openai \
|
| 401 |
-
--model_args engine=davinci-002 \
|
| 402 |
-
--tasks lambada_openai,hellaswag \
|
| 403 |
-
--check_integrity
|
| 404 |
-
```
|
| 405 |
-
|
| 406 |
-
## Advanced Usage Tips
|
| 407 |
-
|
| 408 |
-
For models loaded with the HuggingFace `transformers` library, any arguments provided via `--model_args` get passed to the relevant constructor directly. This means that anything you can do with `AutoModel` can be done with our library. For example, you can pass a local path via `pretrained=` or use models finetuned with [PEFT](https://github.com/huggingface/peft) by taking the call you would run to evaluate the base model and add `,peft=PATH` to the `model_args` argument:
|
| 409 |
-
|
| 410 |
-
```bash
|
| 411 |
-
lm_eval --model hf \
|
| 412 |
-
--model_args pretrained=EleutherAI/gpt-j-6b,parallelize=True,load_in_4bit=True,peft=nomic-ai/gpt4all-j-lora \
|
| 413 |
-
--tasks openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq \
|
| 414 |
-
--device cuda:0
|
| 415 |
-
```
|
| 416 |
-
|
| 417 |
-
Models provided as delta weights can be easily loaded using the Hugging Face transformers library. Within --model_args, set the delta argument to specify the delta weights, and use the pretrained argument to designate the relative base model to which they will be applied:
|
| 418 |
-
|
| 419 |
-
```bash
|
| 420 |
-
lm_eval --model hf \
|
| 421 |
-
--model_args pretrained=Ejafa/llama_7B,delta=lmsys/vicuna-7b-delta-v1.1 \
|
| 422 |
-
--tasks hellaswag
|
| 423 |
-
```
|
| 424 |
-
|
| 425 |
-
GPTQ quantized models can be loaded using [GPTQModel](https://github.com/ModelCloud/GPTQModel) (faster) or [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ)
|
| 426 |
-
|
| 427 |
-
GPTQModel: add `,gptqmodel=True` to `model_args`
|
| 428 |
-
|
| 429 |
-
```bash
|
| 430 |
-
lm_eval --model hf \
|
| 431 |
-
--model_args pretrained=model-name-or-path,gptqmodel=True \
|
| 432 |
-
--tasks hellaswag
|
| 433 |
-
```
|
| 434 |
-
|
| 435 |
-
AutoGPTQ: add `,autogptq=True` to `model_args`:
|
| 436 |
-
|
| 437 |
-
```bash
|
| 438 |
-
lm_eval --model hf \
|
| 439 |
-
--model_args pretrained=model-name-or-path,autogptq=model.safetensors,gptq_use_triton=True \
|
| 440 |
-
--tasks hellaswag
|
| 441 |
-
```
|
| 442 |
-
|
| 443 |
-
We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`.
|
| 444 |
-
|
| 445 |
-
## Saving & Caching Results
|
| 446 |
-
|
| 447 |
-
To save evaluation results provide an `--output_path`. We also support logging model responses with the `--log_samples` flag for post-hoc analysis.
|
| 448 |
-
|
| 449 |
-
> [!TIP]
|
| 450 |
-
> Use `--use_cache <DIR>` to cache evaluation results and skip previously evaluated samples when resuming runs of the same (model, task) pairs. Note that caching is rank-dependent, so restart with the same GPU count if interrupted. You can also use --cache_requests to save dataset preprocessing steps for faster evaluation resumption.
|
| 451 |
-
|
| 452 |
-
To push results and samples to the Hugging Face Hub, first ensure an access token with write access is set in the `HF_TOKEN` environment variable. Then, use the `--hf_hub_log_args` flag to specify the organization, repository name, repository visibility, and whether to push results and samples to the Hub - [example dataset on the HF Hub](https://huggingface.co/datasets/KonradSzafer/lm-eval-results-demo). For instance:
|
| 453 |
-
|
| 454 |
-
```bash
|
| 455 |
-
lm_eval --model hf \
|
| 456 |
-
--model_args pretrained=model-name-or-path,autogptq=model.safetensors,gptq_use_triton=True \
|
| 457 |
-
--tasks hellaswag \
|
| 458 |
-
--log_samples \
|
| 459 |
-
--output_path results \
|
| 460 |
-
--hf_hub_log_args hub_results_org=EleutherAI,hub_repo_name=lm-eval-results,push_results_to_hub=True,push_samples_to_hub=True,public_repo=False \
|
| 461 |
-
```
|
| 462 |
-
|
| 463 |
-
This allows you to easily download the results and samples from the Hub, using:
|
| 464 |
-
|
| 465 |
-
```python
|
| 466 |
-
from datasets import load_dataset
|
| 467 |
-
|
| 468 |
-
load_dataset("EleutherAI/lm-eval-results-private", "hellaswag", "latest")
|
| 469 |
-
```
|
| 470 |
-
|
| 471 |
-
For a full list of supported arguments, check out the [interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md) guide in our documentation!
|
| 472 |
-
|
| 473 |
-
## Visualizing Results
|
| 474 |
-
|
| 475 |
-
You can seamlessly visualize and analyze the results of your evaluation harness runs using both Weights & Biases (W&B) and Zeno.
|
| 476 |
-
|
| 477 |
-
### Zeno
|
| 478 |
-
|
| 479 |
-
You can use [Zeno](https://zenoml.com) to visualize the results of your eval harness runs.
|
| 480 |
-
|
| 481 |
-
First, head to [hub.zenoml.com](https://hub.zenoml.com) to create an account and get an API key [on your account page](https://hub.zenoml.com/account).
|
| 482 |
-
Add this key as an environment variable:
|
| 483 |
-
|
| 484 |
-
```bash
|
| 485 |
-
export ZENO_API_KEY=[your api key]
|
| 486 |
-
```
|
| 487 |
-
|
| 488 |
-
You'll also need to install the `lm_eval[zeno]` package extra.
|
| 489 |
-
|
| 490 |
-
To visualize the results, run the eval harness with the `log_samples` and `output_path` flags.
|
| 491 |
-
We expect `output_path` to contain multiple folders that represent individual model names.
|
| 492 |
-
You can thus run your evaluation on any number of tasks and models and upload all of the results as projects on Zeno.
|
| 493 |
-
|
| 494 |
-
```bash
|
| 495 |
-
lm_eval \
|
| 496 |
-
--model hf \
|
| 497 |
-
--model_args pretrained=EleutherAI/gpt-j-6B \
|
| 498 |
-
--tasks hellaswag \
|
| 499 |
-
--device cuda:0 \
|
| 500 |
-
--batch_size 8 \
|
| 501 |
-
--log_samples \
|
| 502 |
-
--output_path output/gpt-j-6B
|
| 503 |
-
```
|
| 504 |
-
|
| 505 |
-
Then, you can upload the resulting data using the `zeno_visualize` script:
|
| 506 |
-
|
| 507 |
-
```bash
|
| 508 |
-
python scripts/zeno_visualize.py \
|
| 509 |
-
--data_path output \
|
| 510 |
-
--project_name "Eleuther Project"
|
| 511 |
-
```
|
| 512 |
-
|
| 513 |
-
This will use all subfolders in `data_path` as different models and upload all tasks within these model folders to Zeno.
|
| 514 |
-
If you run the eval harness on multiple tasks, the `project_name` will be used as a prefix and one project will be created per task.
|
| 515 |
-
|
| 516 |
-
You can find an example of this workflow in [examples/visualize-zeno.ipynb](examples/visualize-zeno.ipynb).
|
| 517 |
-
|
| 518 |
-
### Weights and Biases
|
| 519 |
-
|
| 520 |
-
With the [Weights and Biases](https://wandb.ai/site) integration, you can now spend more time extracting deeper insights into your evaluation results. The integration is designed to streamline the process of logging and visualizing experiment results using the Weights & Biases (W&B) platform.
|
| 521 |
-
|
| 522 |
-
The integration provide functionalities
|
| 523 |
-
|
| 524 |
-
- to automatically log the evaluation results,
|
| 525 |
-
- log the samples as W&B Tables for easy visualization,
|
| 526 |
-
- log the `results.json` file as an artifact for version control,
|
| 527 |
-
- log the `<task_name>_eval_samples.json` file if the samples are logged,
|
| 528 |
-
- generate a comprehensive report for analysis and visualization with all the important metric,
|
| 529 |
-
- log task and cli specific configs,
|
| 530 |
-
- and more out of the box like the command used to run the evaluation, GPU/CPU counts, timestamp, etc.
|
| 531 |
-
|
| 532 |
-
First you'll need to install the lm_eval[wandb] package extra. Do `pip install lm_eval[wandb]`.
|
| 533 |
-
|
| 534 |
-
Authenticate your machine with an your unique W&B token. Visit https://wandb.ai/authorize to get one. Do `wandb login` in your command line terminal.
|
| 535 |
-
|
| 536 |
-
Run eval harness as usual with a `wandb_args` flag. Use this flag to provide arguments for initializing a wandb run ([wandb.init](https://docs.wandb.ai/ref/python/init)) as comma separated string arguments.
|
| 537 |
-
|
| 538 |
-
```bash
|
| 539 |
-
lm_eval \
|
| 540 |
-
--model hf \
|
| 541 |
-
--model_args pretrained=microsoft/phi-2,trust_remote_code=True \
|
| 542 |
-
--tasks hellaswag,mmlu_abstract_algebra \
|
| 543 |
-
--device cuda:0 \
|
| 544 |
-
--batch_size 8 \
|
| 545 |
-
--output_path output/phi-2 \
|
| 546 |
-
--limit 10 \
|
| 547 |
-
--wandb_args project=lm-eval-harness-integration \
|
| 548 |
-
--log_samples
|
| 549 |
-
```
|
| 550 |
-
|
| 551 |
-
In the stdout, you will find the link to the W&B run page as well as link to the generated report. You can find an example of this workflow in [examples/visualize-wandb.ipynb](examples/visualize-wandb.ipynb), and an example of how to integrate it beyond the CLI.
|
| 552 |
-
|
| 553 |
-
## How to Contribute or Learn More?
|
| 554 |
-
|
| 555 |
-
For more information on the library and how everything fits together, check out all of our [documentation pages](https://github.com/EleutherAI/lm-evaluation-harness/tree/main/docs)! We plan to post a larger roadmap of desired + planned library improvements soon, with more information on how contributors can help.
|
| 556 |
-
|
| 557 |
-
### Implementing new tasks
|
| 558 |
-
|
| 559 |
-
To implement a new task in the eval harness, see [this guide](./docs/new_task_guide.md).
|
| 560 |
-
|
| 561 |
-
In general, we follow this priority list for addressing concerns about prompting and other eval details:
|
| 562 |
-
|
| 563 |
-
1. If there is widespread agreement among people who train LLMs, use the agreed upon procedure.
|
| 564 |
-
2. If there is a clear and unambiguous official implementation, use that procedure.
|
| 565 |
-
3. If there is widespread agreement among people who evaluate LLMs, use the agreed upon procedure.
|
| 566 |
-
4. If there are multiple common implementations but not universal or widespread agreement, use our preferred option among the common implementations. As before, prioritize choosing from among the implementations found in LLM training papers.
|
| 567 |
-
|
| 568 |
-
These are guidelines and not rules, and can be overruled in special circumstances.
|
| 569 |
-
|
| 570 |
-
We try to prioritize agreement with the procedures used by other groups to decrease the harm when people inevitably compare runs across different papers despite our discouragement of the practice. Historically, we also prioritized the implementation from [Language Models are Few Shot Learners](https://arxiv.org/abs/2005.14165) as our original goal was specifically to compare results with that paper.
|
| 571 |
-
|
| 572 |
-
### Support
|
| 573 |
-
|
| 574 |
-
The best way to get support is to open an issue on this repo or join the [EleutherAI Discord server](https://discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for receiving support for our releases. If you've used the library and have had a positive (or negative) experience, we'd love to hear from you!
|
| 575 |
-
|
| 576 |
-
## Optional Extras
|
| 577 |
-
|
| 578 |
-
Extras dependencies can be installed via `pip install -e ".[NAME]"`
|
| 579 |
-
|
| 580 |
-
| Name | Use |
|
| 581 |
-
| -------------------- | ----------------------------------------------------- |
|
| 582 |
-
| api | For using api models (Anthropic, OpenAI API) |
|
| 583 |
-
| audiolm_qwen | For running Qwen2 audio models |
|
| 584 |
-
| deepsparse | For running NM's DeepSparse models |
|
| 585 |
-
| dev | For linting PRs and contributions |
|
| 586 |
-
| gptq | For loading models with AutoGPTQ |
|
| 587 |
-
| gptqmodel | For loading models with GPTQModel |
|
| 588 |
-
| hf_transfer | For speeding up HF Hub file downloads |
|
| 589 |
-
| ibm_watsonx_ai | For using IBM watsonx.ai model apis |
|
| 590 |
-
| ifeval | For running the IFEval task |
|
| 591 |
-
| ipex | For running on optimum-intel ipex backend |
|
| 592 |
-
| japanese_leaderboard | For running Japanese LLM Leaderboard tasks |
|
| 593 |
-
| longbench | For running LongBench tasks |
|
| 594 |
-
| mamba | For loading Mamba SSM models |
|
| 595 |
-
| math | For running math task answer checking |
|
| 596 |
-
| multilingual | For multilingual tokenizers |
|
| 597 |
-
| neuronx | For running on AWS inf2 instances |
|
| 598 |
-
| optimum | For running Intel OpenVINO models |
|
| 599 |
-
| promptsource | For using PromptSource prompts |
|
| 600 |
-
| ruler | For running RULER tasks |
|
| 601 |
-
| sae_lens | For using SAELens to steer models |
|
| 602 |
-
| sentencepiece | For using the sentencepiece tokenizer |
|
| 603 |
-
| sparseml | For using NM's SparseML models |
|
| 604 |
-
| sparsify | For using Sparsify to steer models |
|
| 605 |
-
| testing | For running library test suite |
|
| 606 |
-
| vllm | For loading models with vLLM |
|
| 607 |
-
| wandb | For integration with `Weights and Biases` platform |
|
| 608 |
-
| zeno | For visualizing results with Zeno |
|
| 609 |
-
| -------------------- | ----------------------------------------------------- |
|
| 610 |
-
| all | Loads all extras (not recommended) |
|
| 611 |
-
|
| 612 |
-
## Cite as
|
| 613 |
-
|
| 614 |
-
```text
|
| 615 |
-
@misc{eval-harness,
|
| 616 |
-
author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy},
|
| 617 |
-
title = {The Language Model Evaluation Harness},
|
| 618 |
-
month = 07,
|
| 619 |
-
year = 2024,
|
| 620 |
-
publisher = {Zenodo},
|
| 621 |
-
version = {v0.4.3},
|
| 622 |
-
doi = {10.5281/zenodo.12608602},
|
| 623 |
-
url = {https://zenodo.org/records/12608602}
|
| 624 |
-
}
|
| 625 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/ignore.txt
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
ROUGE
|
| 2 |
-
rouge
|
| 3 |
-
nin
|
| 4 |
-
maka
|
| 5 |
-
mor
|
| 6 |
-
te
|
| 7 |
-
ond
|
| 8 |
-
extraversion
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/__init__.py
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import os
|
| 3 |
-
|
| 4 |
-
from .evaluator import evaluate, simple_evaluate
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
__version__ = "0.4.8"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/__main__.py
DELETED
|
@@ -1,530 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import json
|
| 3 |
-
import logging
|
| 4 |
-
import os
|
| 5 |
-
import sys
|
| 6 |
-
from functools import partial
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from typing import Union
|
| 9 |
-
|
| 10 |
-
from lm_eval import evaluator, utils
|
| 11 |
-
from lm_eval.evaluator import request_caching_arg_to_dict
|
| 12 |
-
from lm_eval.loggers import EvaluationTracker, WandbLogger
|
| 13 |
-
from lm_eval.tasks import TaskManager
|
| 14 |
-
from lm_eval.utils import (
|
| 15 |
-
handle_non_serializable,
|
| 16 |
-
make_table,
|
| 17 |
-
simple_parse_args_string,
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def try_parse_json(value: str) -> Union[str, dict, None]:
|
| 22 |
-
if value is None:
|
| 23 |
-
return None
|
| 24 |
-
try:
|
| 25 |
-
return json.loads(value)
|
| 26 |
-
except json.JSONDecodeError:
|
| 27 |
-
if "{" in value:
|
| 28 |
-
raise argparse.ArgumentTypeError(
|
| 29 |
-
f"Invalid JSON: {value}. Hint: Use double quotes for JSON strings."
|
| 30 |
-
)
|
| 31 |
-
return value
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def _int_or_none_list_arg_type(
|
| 35 |
-
min_len: int, max_len: int, defaults: str, value: str, split_char: str = ","
|
| 36 |
-
):
|
| 37 |
-
def parse_value(item):
|
| 38 |
-
item = item.strip().lower()
|
| 39 |
-
if item == "none":
|
| 40 |
-
return None
|
| 41 |
-
try:
|
| 42 |
-
return int(item)
|
| 43 |
-
except ValueError:
|
| 44 |
-
raise argparse.ArgumentTypeError(f"{item} is not an integer or None")
|
| 45 |
-
|
| 46 |
-
items = [parse_value(v) for v in value.split(split_char)]
|
| 47 |
-
num_items = len(items)
|
| 48 |
-
|
| 49 |
-
if num_items == 1:
|
| 50 |
-
# Makes downstream handling the same for single and multiple values
|
| 51 |
-
items = items * max_len
|
| 52 |
-
elif num_items < min_len or num_items > max_len:
|
| 53 |
-
raise argparse.ArgumentTypeError(
|
| 54 |
-
f"Argument requires {max_len} integers or None, separated by '{split_char}'"
|
| 55 |
-
)
|
| 56 |
-
elif num_items != max_len:
|
| 57 |
-
logging.warning(
|
| 58 |
-
f"Argument requires {max_len} integers or None, separated by '{split_char}'. "
|
| 59 |
-
"Missing values will be filled with defaults."
|
| 60 |
-
)
|
| 61 |
-
default_items = [parse_value(v) for v in defaults.split(split_char)]
|
| 62 |
-
items.extend(
|
| 63 |
-
default_items[num_items:]
|
| 64 |
-
) # extend items list with missing defaults
|
| 65 |
-
|
| 66 |
-
return items
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def check_argument_types(parser: argparse.ArgumentParser):
|
| 70 |
-
"""
|
| 71 |
-
Check to make sure all CLI args are typed, raises error if not
|
| 72 |
-
"""
|
| 73 |
-
for action in parser._actions:
|
| 74 |
-
if action.dest != "help" and not action.const:
|
| 75 |
-
if action.type is None:
|
| 76 |
-
raise ValueError(
|
| 77 |
-
f"Argument '{action.dest}' doesn't have a type specified."
|
| 78 |
-
)
|
| 79 |
-
else:
|
| 80 |
-
continue
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def setup_parser() -> argparse.ArgumentParser:
|
| 84 |
-
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
|
| 85 |
-
parser.add_argument(
|
| 86 |
-
"--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
|
| 87 |
-
)
|
| 88 |
-
parser.add_argument(
|
| 89 |
-
"--tasks",
|
| 90 |
-
"-t",
|
| 91 |
-
default=None,
|
| 92 |
-
type=str,
|
| 93 |
-
metavar="task1,task2",
|
| 94 |
-
help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
|
| 95 |
-
)
|
| 96 |
-
parser.add_argument(
|
| 97 |
-
"--model_args",
|
| 98 |
-
"-a",
|
| 99 |
-
default="",
|
| 100 |
-
type=try_parse_json,
|
| 101 |
-
help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'""",
|
| 102 |
-
)
|
| 103 |
-
parser.add_argument(
|
| 104 |
-
"--num_fewshot",
|
| 105 |
-
"-f",
|
| 106 |
-
type=int,
|
| 107 |
-
default=None,
|
| 108 |
-
metavar="N",
|
| 109 |
-
help="Number of examples in few-shot context",
|
| 110 |
-
)
|
| 111 |
-
parser.add_argument(
|
| 112 |
-
"--batch_size",
|
| 113 |
-
"-b",
|
| 114 |
-
type=str,
|
| 115 |
-
default=1,
|
| 116 |
-
metavar="auto|auto:N|N",
|
| 117 |
-
help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
|
| 118 |
-
)
|
| 119 |
-
parser.add_argument(
|
| 120 |
-
"--max_batch_size",
|
| 121 |
-
type=int,
|
| 122 |
-
default=None,
|
| 123 |
-
metavar="N",
|
| 124 |
-
help="Maximal batch size to try with --batch_size auto.",
|
| 125 |
-
)
|
| 126 |
-
parser.add_argument(
|
| 127 |
-
"--device",
|
| 128 |
-
type=str,
|
| 129 |
-
default=None,
|
| 130 |
-
help="Device to use (e.g. cuda, cuda:0, cpu).",
|
| 131 |
-
)
|
| 132 |
-
parser.add_argument(
|
| 133 |
-
"--output_path",
|
| 134 |
-
"-o",
|
| 135 |
-
default=None,
|
| 136 |
-
type=str,
|
| 137 |
-
metavar="DIR|DIR/file.json",
|
| 138 |
-
help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
|
| 139 |
-
)
|
| 140 |
-
parser.add_argument(
|
| 141 |
-
"--limit",
|
| 142 |
-
"-L",
|
| 143 |
-
type=float,
|
| 144 |
-
default=None,
|
| 145 |
-
metavar="N|0<N<1",
|
| 146 |
-
help="Limit the number of examples per task. "
|
| 147 |
-
"If <1, limit is a percentage of the total number of examples.",
|
| 148 |
-
)
|
| 149 |
-
parser.add_argument(
|
| 150 |
-
"--samples",
|
| 151 |
-
"-E",
|
| 152 |
-
default=None,
|
| 153 |
-
type=str,
|
| 154 |
-
metavar="/path/to/json",
|
| 155 |
-
help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
|
| 156 |
-
)
|
| 157 |
-
parser.add_argument(
|
| 158 |
-
"--use_cache",
|
| 159 |
-
"-c",
|
| 160 |
-
type=str,
|
| 161 |
-
default=None,
|
| 162 |
-
metavar="DIR",
|
| 163 |
-
help="A path to a sqlite db file for caching model responses. `None` if not caching.",
|
| 164 |
-
)
|
| 165 |
-
parser.add_argument(
|
| 166 |
-
"--cache_requests",
|
| 167 |
-
type=str,
|
| 168 |
-
default=None,
|
| 169 |
-
choices=["true", "refresh", "delete"],
|
| 170 |
-
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
|
| 171 |
-
)
|
| 172 |
-
parser.add_argument(
|
| 173 |
-
"--check_integrity",
|
| 174 |
-
action="store_true",
|
| 175 |
-
help="Whether to run the relevant part of the test suite for the tasks.",
|
| 176 |
-
)
|
| 177 |
-
parser.add_argument(
|
| 178 |
-
"--write_out",
|
| 179 |
-
"-w",
|
| 180 |
-
action="store_true",
|
| 181 |
-
default=False,
|
| 182 |
-
help="Prints the prompt for the first few documents.",
|
| 183 |
-
)
|
| 184 |
-
parser.add_argument(
|
| 185 |
-
"--log_samples",
|
| 186 |
-
"-s",
|
| 187 |
-
action="store_true",
|
| 188 |
-
default=False,
|
| 189 |
-
help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
|
| 190 |
-
)
|
| 191 |
-
parser.add_argument(
|
| 192 |
-
"--system_instruction",
|
| 193 |
-
type=str,
|
| 194 |
-
default=None,
|
| 195 |
-
help="System instruction to be used in the prompt",
|
| 196 |
-
)
|
| 197 |
-
parser.add_argument(
|
| 198 |
-
"--apply_chat_template",
|
| 199 |
-
type=str,
|
| 200 |
-
nargs="?",
|
| 201 |
-
const=True,
|
| 202 |
-
default=False,
|
| 203 |
-
help=(
|
| 204 |
-
"If True, apply chat template to the prompt. "
|
| 205 |
-
"Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
|
| 206 |
-
"To apply a specific template from the available list of templates, provide the template name as an argument. "
|
| 207 |
-
"E.g. `--apply_chat_template template_name`"
|
| 208 |
-
),
|
| 209 |
-
)
|
| 210 |
-
parser.add_argument(
|
| 211 |
-
"--fewshot_as_multiturn",
|
| 212 |
-
action="store_true",
|
| 213 |
-
default=False,
|
| 214 |
-
help="If True, uses the fewshot as a multi-turn conversation",
|
| 215 |
-
)
|
| 216 |
-
parser.add_argument(
|
| 217 |
-
"--show_config",
|
| 218 |
-
action="store_true",
|
| 219 |
-
default=False,
|
| 220 |
-
help="If True, shows the the full config of all tasks at the end of the evaluation.",
|
| 221 |
-
)
|
| 222 |
-
parser.add_argument(
|
| 223 |
-
"--include_path",
|
| 224 |
-
type=str,
|
| 225 |
-
default=None,
|
| 226 |
-
metavar="DIR",
|
| 227 |
-
help="Additional path to include if there are external tasks to include.",
|
| 228 |
-
)
|
| 229 |
-
parser.add_argument(
|
| 230 |
-
"--gen_kwargs",
|
| 231 |
-
type=try_parse_json,
|
| 232 |
-
default=None,
|
| 233 |
-
help=(
|
| 234 |
-
"Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
|
| 235 |
-
""" e.g. '{"temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
|
| 236 |
-
),
|
| 237 |
-
)
|
| 238 |
-
parser.add_argument(
|
| 239 |
-
"--verbosity",
|
| 240 |
-
"-v",
|
| 241 |
-
type=str.upper,
|
| 242 |
-
default=None,
|
| 243 |
-
metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
|
| 244 |
-
help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
|
| 245 |
-
)
|
| 246 |
-
parser.add_argument(
|
| 247 |
-
"--wandb_args",
|
| 248 |
-
type=str,
|
| 249 |
-
default="",
|
| 250 |
-
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
|
| 251 |
-
)
|
| 252 |
-
parser.add_argument(
|
| 253 |
-
"--wandb_config_args",
|
| 254 |
-
type=str,
|
| 255 |
-
default="",
|
| 256 |
-
help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3",
|
| 257 |
-
)
|
| 258 |
-
parser.add_argument(
|
| 259 |
-
"--hf_hub_log_args",
|
| 260 |
-
type=str,
|
| 261 |
-
default="",
|
| 262 |
-
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
|
| 263 |
-
)
|
| 264 |
-
parser.add_argument(
|
| 265 |
-
"--predict_only",
|
| 266 |
-
"-x",
|
| 267 |
-
action="store_true",
|
| 268 |
-
default=False,
|
| 269 |
-
help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
|
| 270 |
-
)
|
| 271 |
-
default_seed_string = "0,1234,1234,1234"
|
| 272 |
-
parser.add_argument(
|
| 273 |
-
"--seed",
|
| 274 |
-
type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
|
| 275 |
-
default=default_seed_string, # for backward compatibility
|
| 276 |
-
help=(
|
| 277 |
-
"Set seed for python's random, numpy, torch, and fewshot sampling.\n"
|
| 278 |
-
"Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
|
| 279 |
-
"respectively, or a single integer to set the same seed for all four.\n"
|
| 280 |
-
f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
|
| 281 |
-
"(for backward compatibility).\n"
|
| 282 |
-
"E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
|
| 283 |
-
"Here numpy's seed is not set since the second value is `None`.\n"
|
| 284 |
-
"E.g, `--seed 42` sets all four seeds to 42."
|
| 285 |
-
),
|
| 286 |
-
)
|
| 287 |
-
parser.add_argument(
|
| 288 |
-
"--trust_remote_code",
|
| 289 |
-
action="store_true",
|
| 290 |
-
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
|
| 291 |
-
)
|
| 292 |
-
parser.add_argument(
|
| 293 |
-
"--confirm_run_unsafe_code",
|
| 294 |
-
action="store_true",
|
| 295 |
-
help="Confirm that you understand the risks of running unsafe code for tasks that require it",
|
| 296 |
-
)
|
| 297 |
-
parser.add_argument(
|
| 298 |
-
"--metadata",
|
| 299 |
-
type=json.loads,
|
| 300 |
-
default=None,
|
| 301 |
-
help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
|
| 302 |
-
)
|
| 303 |
-
return parser
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
|
| 307 |
-
check_argument_types(parser)
|
| 308 |
-
return parser.parse_args()
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
|
| 312 |
-
if not args:
|
| 313 |
-
# we allow for args to be passed externally, else we parse them ourselves
|
| 314 |
-
parser = setup_parser()
|
| 315 |
-
args = parse_eval_args(parser)
|
| 316 |
-
|
| 317 |
-
if args.wandb_args:
|
| 318 |
-
wandb_args_dict = simple_parse_args_string(args.wandb_args)
|
| 319 |
-
wandb_config_args_dict = simple_parse_args_string(args.wandb_config_args)
|
| 320 |
-
wandb_logger = WandbLogger(wandb_args_dict, wandb_config_args_dict)
|
| 321 |
-
|
| 322 |
-
utils.setup_logging(args.verbosity)
|
| 323 |
-
eval_logger = logging.getLogger(__name__)
|
| 324 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 325 |
-
|
| 326 |
-
# update the evaluation tracker args with the output path and the HF token
|
| 327 |
-
if args.output_path:
|
| 328 |
-
args.hf_hub_log_args += f",output_path={args.output_path}"
|
| 329 |
-
if os.environ.get("HF_TOKEN", None):
|
| 330 |
-
args.hf_hub_log_args += f",token={os.environ.get('HF_TOKEN')}"
|
| 331 |
-
evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
|
| 332 |
-
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
|
| 333 |
-
|
| 334 |
-
if args.predict_only:
|
| 335 |
-
args.log_samples = True
|
| 336 |
-
if (args.log_samples or args.predict_only) and not args.output_path:
|
| 337 |
-
raise ValueError(
|
| 338 |
-
"Specify --output_path if providing --log_samples or --predict_only"
|
| 339 |
-
)
|
| 340 |
-
|
| 341 |
-
if args.fewshot_as_multiturn and args.apply_chat_template is False:
|
| 342 |
-
raise ValueError(
|
| 343 |
-
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set (either to `True` or to the chosen template name)."
|
| 344 |
-
)
|
| 345 |
-
|
| 346 |
-
if args.include_path is not None:
|
| 347 |
-
eval_logger.info(f"Including path: {args.include_path}")
|
| 348 |
-
metadata = (
|
| 349 |
-
simple_parse_args_string(args.model_args)
|
| 350 |
-
if isinstance(args.model_args, str)
|
| 351 |
-
else args.model_args
|
| 352 |
-
if isinstance(args.model_args, dict)
|
| 353 |
-
else {}
|
| 354 |
-
) | (
|
| 355 |
-
args.metadata
|
| 356 |
-
if isinstance(args.metadata, dict)
|
| 357 |
-
else simple_parse_args_string(args.metadata)
|
| 358 |
-
)
|
| 359 |
-
|
| 360 |
-
task_manager = TaskManager(include_path=args.include_path, metadata=metadata)
|
| 361 |
-
|
| 362 |
-
if "push_samples_to_hub" in evaluation_tracker_args and not args.log_samples:
|
| 363 |
-
eval_logger.warning(
|
| 364 |
-
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
|
| 365 |
-
)
|
| 366 |
-
|
| 367 |
-
if args.limit:
|
| 368 |
-
eval_logger.warning(
|
| 369 |
-
" --limit SHOULD ONLY BE USED FOR TESTING."
|
| 370 |
-
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
|
| 371 |
-
)
|
| 372 |
-
if args.samples:
|
| 373 |
-
assert args.limit is None, (
|
| 374 |
-
"If --samples is not None, then --limit must be None."
|
| 375 |
-
)
|
| 376 |
-
if (samples := Path(args.samples)).is_file():
|
| 377 |
-
args.samples = json.loads(samples.read_text())
|
| 378 |
-
else:
|
| 379 |
-
args.samples = json.loads(args.samples)
|
| 380 |
-
|
| 381 |
-
if args.tasks is None:
|
| 382 |
-
eval_logger.error("Need to specify task to evaluate.")
|
| 383 |
-
sys.exit()
|
| 384 |
-
elif args.tasks == "list":
|
| 385 |
-
print(task_manager.list_all_tasks())
|
| 386 |
-
sys.exit()
|
| 387 |
-
elif args.tasks == "list_groups":
|
| 388 |
-
print(task_manager.list_all_tasks(list_subtasks=False, list_tags=False))
|
| 389 |
-
sys.exit()
|
| 390 |
-
elif args.tasks == "list_tags":
|
| 391 |
-
print(task_manager.list_all_tasks(list_groups=False, list_subtasks=False))
|
| 392 |
-
sys.exit()
|
| 393 |
-
elif args.tasks == "list_subtasks":
|
| 394 |
-
print(task_manager.list_all_tasks(list_groups=False, list_tags=False))
|
| 395 |
-
sys.exit()
|
| 396 |
-
else:
|
| 397 |
-
if os.path.isdir(args.tasks):
|
| 398 |
-
import glob
|
| 399 |
-
|
| 400 |
-
task_names = []
|
| 401 |
-
yaml_path = os.path.join(args.tasks, "*.yaml")
|
| 402 |
-
for yaml_file in glob.glob(yaml_path):
|
| 403 |
-
config = utils.load_yaml_config(yaml_file)
|
| 404 |
-
task_names.append(config)
|
| 405 |
-
else:
|
| 406 |
-
task_list = args.tasks.split(",")
|
| 407 |
-
task_names = task_manager.match_tasks(task_list)
|
| 408 |
-
for task in [task for task in task_list if task not in task_names]:
|
| 409 |
-
if os.path.isfile(task):
|
| 410 |
-
config = utils.load_yaml_config(task)
|
| 411 |
-
task_names.append(config)
|
| 412 |
-
task_missing = [
|
| 413 |
-
task for task in task_list if task not in task_names and "*" not in task
|
| 414 |
-
] # we don't want errors if a wildcard ("*") task name was used
|
| 415 |
-
|
| 416 |
-
if task_missing:
|
| 417 |
-
missing = ", ".join(task_missing)
|
| 418 |
-
eval_logger.error(
|
| 419 |
-
f"Tasks were not found: {missing}\n"
|
| 420 |
-
f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
|
| 421 |
-
)
|
| 422 |
-
raise ValueError(
|
| 423 |
-
f"Tasks not found: {missing}. Try `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above, or pass '--verbosity DEBUG' to troubleshoot task registration issues."
|
| 424 |
-
)
|
| 425 |
-
|
| 426 |
-
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
|
| 427 |
-
if args.trust_remote_code:
|
| 428 |
-
eval_logger.info(
|
| 429 |
-
"Passed `--trust_remote_code`, setting environment variable `HF_DATASETS_TRUST_REMOTE_CODE=true`"
|
| 430 |
-
)
|
| 431 |
-
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
|
| 432 |
-
# because it's already been determined based on the prior env var before launching our
|
| 433 |
-
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
|
| 434 |
-
import datasets
|
| 435 |
-
|
| 436 |
-
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
|
| 437 |
-
|
| 438 |
-
args.model_args = args.model_args + ",trust_remote_code=True"
|
| 439 |
-
(
|
| 440 |
-
eval_logger.info(f"Selected Tasks: {task_names}")
|
| 441 |
-
if eval_logger.getEffectiveLevel() >= logging.INFO
|
| 442 |
-
else print(f"Selected Tasks: {task_names}")
|
| 443 |
-
)
|
| 444 |
-
|
| 445 |
-
request_caching_args = request_caching_arg_to_dict(
|
| 446 |
-
cache_requests=args.cache_requests
|
| 447 |
-
)
|
| 448 |
-
|
| 449 |
-
results = evaluator.simple_evaluate(
|
| 450 |
-
model=args.model,
|
| 451 |
-
model_args=args.model_args,
|
| 452 |
-
tasks=task_names,
|
| 453 |
-
num_fewshot=args.num_fewshot,
|
| 454 |
-
batch_size=args.batch_size,
|
| 455 |
-
max_batch_size=args.max_batch_size,
|
| 456 |
-
device=args.device,
|
| 457 |
-
use_cache=args.use_cache,
|
| 458 |
-
limit=args.limit,
|
| 459 |
-
samples=args.samples,
|
| 460 |
-
check_integrity=args.check_integrity,
|
| 461 |
-
write_out=args.write_out,
|
| 462 |
-
log_samples=args.log_samples,
|
| 463 |
-
evaluation_tracker=evaluation_tracker,
|
| 464 |
-
system_instruction=args.system_instruction,
|
| 465 |
-
apply_chat_template=args.apply_chat_template,
|
| 466 |
-
fewshot_as_multiturn=args.fewshot_as_multiturn,
|
| 467 |
-
gen_kwargs=args.gen_kwargs,
|
| 468 |
-
task_manager=task_manager,
|
| 469 |
-
predict_only=args.predict_only,
|
| 470 |
-
random_seed=args.seed[0],
|
| 471 |
-
numpy_random_seed=args.seed[1],
|
| 472 |
-
torch_random_seed=args.seed[2],
|
| 473 |
-
fewshot_random_seed=args.seed[3],
|
| 474 |
-
confirm_run_unsafe_code=args.confirm_run_unsafe_code,
|
| 475 |
-
metadata=metadata,
|
| 476 |
-
**request_caching_args,
|
| 477 |
-
)
|
| 478 |
-
|
| 479 |
-
if results is not None:
|
| 480 |
-
if args.log_samples:
|
| 481 |
-
samples = results.pop("samples")
|
| 482 |
-
dumped = json.dumps(
|
| 483 |
-
results, indent=2, default=handle_non_serializable, ensure_ascii=False
|
| 484 |
-
)
|
| 485 |
-
if args.show_config:
|
| 486 |
-
print(dumped)
|
| 487 |
-
|
| 488 |
-
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
|
| 489 |
-
|
| 490 |
-
# Add W&B logging
|
| 491 |
-
if args.wandb_args:
|
| 492 |
-
try:
|
| 493 |
-
wandb_logger.post_init(results)
|
| 494 |
-
wandb_logger.log_eval_result()
|
| 495 |
-
if args.log_samples:
|
| 496 |
-
wandb_logger.log_eval_samples(samples)
|
| 497 |
-
except Exception as e:
|
| 498 |
-
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")
|
| 499 |
-
|
| 500 |
-
evaluation_tracker.save_results_aggregated(
|
| 501 |
-
results=results, samples=samples if args.log_samples else None
|
| 502 |
-
)
|
| 503 |
-
|
| 504 |
-
if args.log_samples:
|
| 505 |
-
for task_name, config in results["configs"].items():
|
| 506 |
-
evaluation_tracker.save_results_samples(
|
| 507 |
-
task_name=task_name, samples=samples[task_name]
|
| 508 |
-
)
|
| 509 |
-
|
| 510 |
-
if (
|
| 511 |
-
evaluation_tracker.push_results_to_hub
|
| 512 |
-
or evaluation_tracker.push_samples_to_hub
|
| 513 |
-
):
|
| 514 |
-
evaluation_tracker.recreate_metadata_card()
|
| 515 |
-
|
| 516 |
-
print(
|
| 517 |
-
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
|
| 518 |
-
f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
|
| 519 |
-
)
|
| 520 |
-
print(make_table(results))
|
| 521 |
-
if "groups" in results:
|
| 522 |
-
print(make_table(results, "groups"))
|
| 523 |
-
|
| 524 |
-
if args.wandb_args:
|
| 525 |
-
# Tear down wandb run once all the logging is done.
|
| 526 |
-
wandb_logger.run.finish()
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
if __name__ == "__main__":
|
| 530 |
-
cli_evaluate()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/api/filter.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
from abc import ABC, abstractmethod
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
-
from typing import Callable, Iterable, List, Union
|
| 4 |
-
|
| 5 |
-
from lm_eval.api.instance import Instance
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
class Filter(ABC):
|
| 9 |
-
"""
|
| 10 |
-
Filter classes operate on a per-task level.
|
| 11 |
-
They take all model outputs (`instance.resps` for all `task.instances`)
|
| 12 |
-
across all instances of a task, and perform operations.
|
| 13 |
-
In a single run, one can configure any number of separate filters or lists of filters.
|
| 14 |
-
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
def __init__(self, **kwargs) -> None:
|
| 18 |
-
"""
|
| 19 |
-
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 20 |
-
"""
|
| 21 |
-
|
| 22 |
-
@abstractmethod
|
| 23 |
-
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
|
| 24 |
-
"""
|
| 25 |
-
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
|
| 26 |
-
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
|
| 27 |
-
if pass in [<inst.resps for instance 0>, <inst.resps for instance 1>] should return
|
| 28 |
-
[<filtered resps for instance 0>, <filtered resps for instance 1>]
|
| 29 |
-
"""
|
| 30 |
-
return resps
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
@dataclass
|
| 34 |
-
class FilterEnsemble:
|
| 35 |
-
"""
|
| 36 |
-
FilterEnsemble creates a pipeline applying multiple filters.
|
| 37 |
-
Its intended usage is to stack multiple post-processing steps in order.
|
| 38 |
-
`task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
|
| 39 |
-
pipeline separately.
|
| 40 |
-
"""
|
| 41 |
-
|
| 42 |
-
name: str
|
| 43 |
-
filters: List[Callable[[], Filter]]
|
| 44 |
-
|
| 45 |
-
def apply(self, instances: List[Instance]) -> None:
|
| 46 |
-
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
|
| 47 |
-
resps, docs = list(resps), list(docs)
|
| 48 |
-
|
| 49 |
-
for f in self.filters:
|
| 50 |
-
# apply filters in sequence
|
| 51 |
-
resps = f().apply(resps, docs)
|
| 52 |
-
|
| 53 |
-
# add the end results after filtering to filtered_requests of their respective source instances.
|
| 54 |
-
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
|
| 55 |
-
for inst, resp in zip(instances, resps):
|
| 56 |
-
inst.filtered_resps[self.name] = resp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/api/group.py
DELETED
|
@@ -1,115 +0,0 @@
|
|
| 1 |
-
import abc
|
| 2 |
-
from dataclasses import asdict, dataclass
|
| 3 |
-
from inspect import getsource
|
| 4 |
-
from typing import Any, Callable, List, Optional, Union
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@dataclass
|
| 8 |
-
class AggMetricConfig(dict):
|
| 9 |
-
metric: Optional[str] = None
|
| 10 |
-
aggregation: Optional[str] = "mean"
|
| 11 |
-
weight_by_size: Optional[str] = False
|
| 12 |
-
# list of filter names which should be incorporated into the aggregated metric.
|
| 13 |
-
filter_list: Optional[Union[str, list]] = "none"
|
| 14 |
-
|
| 15 |
-
def __post_init__(self):
|
| 16 |
-
if self.aggregation != "mean" and not callable(self.aggregation):
|
| 17 |
-
raise ValueError(
|
| 18 |
-
f"Currently, 'mean' is the only pre-defined aggregation across groups' subtasks. Got '{self.aggregation}'."
|
| 19 |
-
)
|
| 20 |
-
|
| 21 |
-
if isinstance(self.filter_list, str):
|
| 22 |
-
self.filter_list = [self.filter_list]
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
@dataclass
|
| 26 |
-
class GroupConfig(dict):
|
| 27 |
-
group: Optional[str] = None
|
| 28 |
-
group_alias: Optional[str] = None
|
| 29 |
-
task: Optional[Union[str, list]] = None
|
| 30 |
-
aggregate_metric_list: Optional[
|
| 31 |
-
Union[List[AggMetricConfig], AggMetricConfig, dict]
|
| 32 |
-
] = None
|
| 33 |
-
metadata: Optional[dict] = (
|
| 34 |
-
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
def __getitem__(self, item):
|
| 38 |
-
return getattr(self, item)
|
| 39 |
-
|
| 40 |
-
def __setitem__(self, item, value):
|
| 41 |
-
return setattr(self, item, value)
|
| 42 |
-
|
| 43 |
-
def __post_init__(self):
|
| 44 |
-
if self.aggregate_metric_list is not None:
|
| 45 |
-
if isinstance(self.aggregate_metric_list, dict):
|
| 46 |
-
self.aggregate_metric_list = [self.aggregate_metric_list]
|
| 47 |
-
|
| 48 |
-
self.aggregate_metric_list = [
|
| 49 |
-
AggMetricConfig(**item) if isinstance(item, dict) else item
|
| 50 |
-
for item in self.aggregate_metric_list
|
| 51 |
-
]
|
| 52 |
-
|
| 53 |
-
def to_dict(self, keep_callable: bool = False) -> dict:
|
| 54 |
-
"""dumps the current config as a dictionary object, as a printable format.
|
| 55 |
-
null fields will not be printed.
|
| 56 |
-
Used for dumping results alongside full task configuration
|
| 57 |
-
|
| 58 |
-
:return: dict
|
| 59 |
-
A printable dictionary version of the TaskConfig object.
|
| 60 |
-
|
| 61 |
-
# TODO: should any default value in the TaskConfig not be printed?
|
| 62 |
-
"""
|
| 63 |
-
cfg_dict = asdict(self)
|
| 64 |
-
# remove values that are `None`
|
| 65 |
-
for k, v in list(cfg_dict.items()):
|
| 66 |
-
if callable(v):
|
| 67 |
-
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
|
| 68 |
-
return cfg_dict
|
| 69 |
-
|
| 70 |
-
def serialize_function(
|
| 71 |
-
self, value: Union[Callable, str], keep_callable=False
|
| 72 |
-
) -> Union[Callable, str]:
|
| 73 |
-
"""Serializes a given function or string.
|
| 74 |
-
|
| 75 |
-
If 'keep_callable' is True, the original callable is returned.
|
| 76 |
-
Otherwise, attempts to return the source code of the callable using 'getsource'.
|
| 77 |
-
"""
|
| 78 |
-
if keep_callable:
|
| 79 |
-
return value
|
| 80 |
-
else:
|
| 81 |
-
try:
|
| 82 |
-
return getsource(value)
|
| 83 |
-
except (TypeError, OSError):
|
| 84 |
-
return str(value)
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
class ConfigurableGroup(abc.ABC):
|
| 88 |
-
def __init__(
|
| 89 |
-
self,
|
| 90 |
-
config: Optional[dict] = None,
|
| 91 |
-
) -> None:
|
| 92 |
-
self._config = GroupConfig(**config)
|
| 93 |
-
|
| 94 |
-
@property
|
| 95 |
-
def group(self):
|
| 96 |
-
return self._config.group
|
| 97 |
-
|
| 98 |
-
@property
|
| 99 |
-
def group_alias(self):
|
| 100 |
-
return self._config.group_alias
|
| 101 |
-
|
| 102 |
-
@property
|
| 103 |
-
def version(self):
|
| 104 |
-
return self._config.version
|
| 105 |
-
|
| 106 |
-
@property
|
| 107 |
-
def config(self):
|
| 108 |
-
return self._config.to_dict()
|
| 109 |
-
|
| 110 |
-
@property
|
| 111 |
-
def group_name(self) -> Any:
|
| 112 |
-
return self._config.group
|
| 113 |
-
|
| 114 |
-
def __repr__(self):
|
| 115 |
-
return f"ConfigurableGroup(group={self.group},group_alias={self.group_alias})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/api/instance.py
DELETED
|
@@ -1,38 +0,0 @@
|
|
| 1 |
-
from dataclasses import dataclass, field
|
| 2 |
-
from typing import Literal, Optional, Tuple
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
OutputType = Literal[
|
| 6 |
-
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
|
| 7 |
-
]
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
@dataclass
|
| 11 |
-
class Instance:
|
| 12 |
-
request_type: OutputType
|
| 13 |
-
doc: dict
|
| 14 |
-
arguments: tuple
|
| 15 |
-
idx: int
|
| 16 |
-
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
|
| 17 |
-
default_factory=lambda: (None, None, None)
|
| 18 |
-
)
|
| 19 |
-
resps: list = field(default_factory=list)
|
| 20 |
-
filtered_resps: dict = field(default_factory=dict)
|
| 21 |
-
|
| 22 |
-
# initialized after init
|
| 23 |
-
task_name: Optional[str] = None
|
| 24 |
-
doc_id: Optional[int] = None
|
| 25 |
-
repeats: Optional[int] = None
|
| 26 |
-
|
| 27 |
-
def __post_init__(self) -> None:
|
| 28 |
-
# unpack metadata field
|
| 29 |
-
self.task_name, self.doc_id, self.repeats = self.metadata
|
| 30 |
-
|
| 31 |
-
@property
|
| 32 |
-
def args(self):
|
| 33 |
-
"""
|
| 34 |
-
Returns (string,) where `string` is the string to calculate loglikelihood over
|
| 35 |
-
"""
|
| 36 |
-
return (
|
| 37 |
-
self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
|
| 38 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/api/metrics.py
DELETED
|
@@ -1,578 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import math
|
| 3 |
-
import random
|
| 4 |
-
import re
|
| 5 |
-
import string
|
| 6 |
-
from collections.abc import Iterable
|
| 7 |
-
from typing import List
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import sacrebleu
|
| 11 |
-
|
| 12 |
-
from lm_eval.api.registry import register_aggregation, register_metric
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
eval_logger = logging.getLogger(__name__)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
# Register Aggregations First
|
| 19 |
-
@register_aggregation("bypass")
|
| 20 |
-
def bypass_agg(arr):
|
| 21 |
-
return 999
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
@register_aggregation("nanmean")
|
| 25 |
-
def nanmean(arr):
|
| 26 |
-
if len(arr) == 0 or all(np.isnan(arr)):
|
| 27 |
-
return np.nan
|
| 28 |
-
return np.nanmean(arr)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
@register_aggregation("mean")
|
| 32 |
-
def mean(arr):
|
| 33 |
-
return sum(arr) / len(arr)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
@register_aggregation("median")
|
| 37 |
-
def median(arr):
|
| 38 |
-
return arr[len(arr) // 2]
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# Certain metrics must be calculated across all documents in a benchmark.
|
| 42 |
-
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
|
| 43 |
-
@register_aggregation("perplexity")
|
| 44 |
-
def perplexity(items):
|
| 45 |
-
return math.exp(-mean(items))
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
@register_aggregation("weighted_perplexity")
|
| 49 |
-
def weighted_perplexity(items):
|
| 50 |
-
return math.exp(-weighted_mean(items))
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
@register_aggregation("bits_per_byte")
|
| 54 |
-
def bits_per_byte(items):
|
| 55 |
-
return -weighted_mean(items) / math.log(2)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
@register_aggregation("f1")
|
| 59 |
-
def f1_score(items):
|
| 60 |
-
from sklearn.metrics import f1_score
|
| 61 |
-
|
| 62 |
-
unzipped_list = list(zip(*items))
|
| 63 |
-
golds = unzipped_list[0]
|
| 64 |
-
preds = unzipped_list[1]
|
| 65 |
-
fscore = f1_score(golds, preds)
|
| 66 |
-
|
| 67 |
-
return np.max(fscore)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
@register_aggregation("matthews_corrcoef")
|
| 71 |
-
def matthews_corrcoef(items):
|
| 72 |
-
from sklearn.metrics import matthews_corrcoef
|
| 73 |
-
|
| 74 |
-
unzipped_list = list(zip(*items))
|
| 75 |
-
golds = unzipped_list[0]
|
| 76 |
-
preds = unzipped_list[1]
|
| 77 |
-
return matthews_corrcoef(golds, preds)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
@register_aggregation("bleu")
|
| 81 |
-
def bleu(items):
|
| 82 |
-
"""The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
|
| 83 |
-
for evaluating a generated sentence to a reference sentence. It counts matching
|
| 84 |
-
n-grams in the candidate translation to n-grams in the reference text, where
|
| 85 |
-
1-gram or unigram would be each token and a bigram comparison would be each
|
| 86 |
-
word pair. The comparison is made regardless of word order
|
| 87 |
-
Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
|
| 88 |
-
Paper: https://www.aclweb.org/anthology/P02-1040/
|
| 89 |
-
|
| 90 |
-
Higher is better
|
| 91 |
-
"""
|
| 92 |
-
refs = list(zip(*items))[0]
|
| 93 |
-
preds = list(zip(*items))[1]
|
| 94 |
-
refs, preds = _sacreformat(refs, preds)
|
| 95 |
-
return sacrebleu.corpus_bleu(preds, refs).score
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
@register_aggregation("chrf")
|
| 99 |
-
def chrf(items):
|
| 100 |
-
"""chrF++ is a tool for automatic evaluation of machine translation output
|
| 101 |
-
based on character n-gram precision and recall enhanced with word n-grams.
|
| 102 |
-
Source: https://github.com/m-popovic/chrF
|
| 103 |
-
Paper: https://www.aclweb.org/anthology/W15-3049.pdf
|
| 104 |
-
|
| 105 |
-
Higher is better # TODO I think
|
| 106 |
-
"""
|
| 107 |
-
refs = list(zip(*items))[0]
|
| 108 |
-
preds = list(zip(*items))[1]
|
| 109 |
-
refs, preds = _sacreformat(refs, preds)
|
| 110 |
-
return sacrebleu.corpus_chrf(preds, refs).score
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
@register_aggregation("ter")
|
| 114 |
-
def ter(items):
|
| 115 |
-
"""Translation Error Rate is an error metric for machine translation that
|
| 116 |
-
measures the number of edits required to change a system output into one
|
| 117 |
-
of the references
|
| 118 |
-
Source: http://www.cs.umd.edu/~snover/tercom/
|
| 119 |
-
Paper: http://mt-archive.info/AMTA-2006-Snover.pdf
|
| 120 |
-
|
| 121 |
-
Lower is better
|
| 122 |
-
"""
|
| 123 |
-
refs = list(zip(*items))[0]
|
| 124 |
-
preds = list(zip(*items))[1]
|
| 125 |
-
refs, preds = _sacreformat(refs, preds)
|
| 126 |
-
return sacrebleu.corpus_ter(preds, refs).score
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
@register_aggregation("brier_score")
|
| 130 |
-
def brier_score(items): # This is a passthrough function
|
| 131 |
-
gold, predictions = list(zip(*items))
|
| 132 |
-
bs, num_class = np.array(predictions).shape
|
| 133 |
-
|
| 134 |
-
gold = list(gold)
|
| 135 |
-
gold_one_hot = np.eye(num_class)[gold]
|
| 136 |
-
return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
@register_metric(
|
| 140 |
-
metric="brier_score",
|
| 141 |
-
higher_is_better=False,
|
| 142 |
-
output_type=["multiple_choice"],
|
| 143 |
-
aggregation="brier_score",
|
| 144 |
-
)
|
| 145 |
-
def brier_score_fn(items): # This is a passthrough function
|
| 146 |
-
return items
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
@register_metric(
|
| 150 |
-
metric="acc",
|
| 151 |
-
higher_is_better=True,
|
| 152 |
-
output_type=["loglikelihood", "multiple_choice"],
|
| 153 |
-
aggregation="mean",
|
| 154 |
-
)
|
| 155 |
-
def acc_fn(items): # This is a passthrough function
|
| 156 |
-
return items
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
@register_metric(
|
| 160 |
-
metric="acc_norm",
|
| 161 |
-
higher_is_better=True,
|
| 162 |
-
output_type=["loglikelihood", "multiple_choice"],
|
| 163 |
-
aggregation="mean",
|
| 164 |
-
)
|
| 165 |
-
def acc_norm_fn(items): # This is a passthrough function
|
| 166 |
-
return items
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
@register_metric(
|
| 170 |
-
metric="acc_mutual_info",
|
| 171 |
-
higher_is_better=True,
|
| 172 |
-
output_type="multiple_choice",
|
| 173 |
-
aggregation="mean",
|
| 174 |
-
)
|
| 175 |
-
def acc_mutual_info_fn(items): # This is a passthrough function
|
| 176 |
-
return items
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
### the code used in the `exact_match_hf_evaluate` function is ported from
|
| 180 |
-
### https://github.com/huggingface/evaluate/blob/main/metrics/exact_match/exact_match.py
|
| 181 |
-
### which is under the apache license.
|
| 182 |
-
|
| 183 |
-
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
| 184 |
-
|
| 185 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 186 |
-
# you may not use this file except in compliance with the License.
|
| 187 |
-
# You may obtain a copy of the License at
|
| 188 |
-
|
| 189 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 193 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 194 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 195 |
-
# See the License for the specific language governing permissions and
|
| 196 |
-
# limitations under the License.
|
| 197 |
-
def exact_match_hf_evaluate(
|
| 198 |
-
predictions,
|
| 199 |
-
references,
|
| 200 |
-
regexes_to_ignore=None,
|
| 201 |
-
ignore_case=False,
|
| 202 |
-
ignore_punctuation=False,
|
| 203 |
-
ignore_numbers=False,
|
| 204 |
-
):
|
| 205 |
-
if regexes_to_ignore is not None:
|
| 206 |
-
for s in regexes_to_ignore:
|
| 207 |
-
predictions = np.array([re.sub(s, "", x) for x in predictions])
|
| 208 |
-
references = np.array([re.sub(s, "", x) for x in references])
|
| 209 |
-
else:
|
| 210 |
-
predictions = np.asarray(predictions)
|
| 211 |
-
references = np.asarray(references)
|
| 212 |
-
|
| 213 |
-
if ignore_case:
|
| 214 |
-
predictions = np.char.lower(predictions)
|
| 215 |
-
references = np.char.lower(references)
|
| 216 |
-
|
| 217 |
-
if ignore_punctuation:
|
| 218 |
-
repl_table = string.punctuation.maketrans("", "", string.punctuation)
|
| 219 |
-
predictions = np.char.translate(predictions, table=repl_table)
|
| 220 |
-
references = np.char.translate(references, table=repl_table)
|
| 221 |
-
|
| 222 |
-
if ignore_numbers:
|
| 223 |
-
repl_table = string.digits.maketrans("", "", string.digits)
|
| 224 |
-
predictions = np.char.translate(predictions, table=repl_table)
|
| 225 |
-
references = np.char.translate(references, table=repl_table)
|
| 226 |
-
|
| 227 |
-
score_list = predictions == references
|
| 228 |
-
|
| 229 |
-
return {"exact_match": np.mean(score_list)}
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
###
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
@register_metric(
|
| 236 |
-
metric="exact_match",
|
| 237 |
-
higher_is_better=True,
|
| 238 |
-
output_type="generate_until",
|
| 239 |
-
aggregation="mean",
|
| 240 |
-
)
|
| 241 |
-
def exact_match_fn(**kwargs):
|
| 242 |
-
return exact_match_hf_evaluate(**kwargs)
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
@register_metric(
|
| 246 |
-
metric="perplexity",
|
| 247 |
-
higher_is_better=False,
|
| 248 |
-
output_type="loglikelihood",
|
| 249 |
-
aggregation="perplexity",
|
| 250 |
-
)
|
| 251 |
-
def perplexity_fn(items): # This is a passthrough function
|
| 252 |
-
return items
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
@register_metric(
|
| 256 |
-
metric="word_perplexity",
|
| 257 |
-
higher_is_better=False,
|
| 258 |
-
output_type="loglikelihood_rolling",
|
| 259 |
-
aggregation="weighted_perplexity",
|
| 260 |
-
)
|
| 261 |
-
def word_perplexity_fn(items): # This is a passthrough function
|
| 262 |
-
return items
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
@register_metric(
|
| 266 |
-
metric="byte_perplexity",
|
| 267 |
-
higher_is_better=False,
|
| 268 |
-
output_type="loglikelihood_rolling",
|
| 269 |
-
aggregation="weighted_perplexity",
|
| 270 |
-
)
|
| 271 |
-
def byte_perplexity_fn(items): # This is a passthrough function
|
| 272 |
-
return items
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
@register_metric(
|
| 276 |
-
metric="bits_per_byte",
|
| 277 |
-
higher_is_better=False,
|
| 278 |
-
output_type="loglikelihood_rolling",
|
| 279 |
-
aggregation="bits_per_byte",
|
| 280 |
-
)
|
| 281 |
-
def bits_per_byte_fn(items): # This is a passthrough function
|
| 282 |
-
return items
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
def pop_stddev(arr):
|
| 286 |
-
mu = mean(arr)
|
| 287 |
-
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
def sample_stddev(arr):
|
| 291 |
-
mu = mean(arr)
|
| 292 |
-
return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
def mean_stderr(arr):
|
| 296 |
-
return sample_stddev(arr) / math.sqrt(len(arr))
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
@register_metric(
|
| 300 |
-
metric="bypass",
|
| 301 |
-
higher_is_better=True,
|
| 302 |
-
output_type=["loglikelihood", "multiple_choice", "generate_until"],
|
| 303 |
-
aggregation="bypass",
|
| 304 |
-
)
|
| 305 |
-
def bypass(items):
|
| 306 |
-
return None
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
@register_metric(
|
| 310 |
-
metric="mcc",
|
| 311 |
-
higher_is_better=True,
|
| 312 |
-
output_type="multiple_choice",
|
| 313 |
-
aggregation="matthews_corrcoef",
|
| 314 |
-
)
|
| 315 |
-
def mcc_fn(items): # This is a passthrough function
|
| 316 |
-
return items
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
@register_metric(
|
| 320 |
-
metric="f1",
|
| 321 |
-
higher_is_better=True,
|
| 322 |
-
output_type="multiple_choice",
|
| 323 |
-
aggregation="f1",
|
| 324 |
-
)
|
| 325 |
-
def f1_fn(items): # This is a passthrough function
|
| 326 |
-
return items
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
@register_metric(
|
| 330 |
-
metric="bleu",
|
| 331 |
-
higher_is_better=True,
|
| 332 |
-
output_type="generate_until",
|
| 333 |
-
aggregation="bleu",
|
| 334 |
-
)
|
| 335 |
-
def bleu_fn(items): # This is a passthrough function
|
| 336 |
-
return items
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
@register_metric(
|
| 340 |
-
metric="chrf",
|
| 341 |
-
higher_is_better=True,
|
| 342 |
-
output_type="generate_until",
|
| 343 |
-
aggregation="chrf",
|
| 344 |
-
)
|
| 345 |
-
def chrf_fn(items): # This is a passthrough function
|
| 346 |
-
return items
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
@register_metric(
|
| 350 |
-
metric="ter",
|
| 351 |
-
higher_is_better=True,
|
| 352 |
-
output_type="generate_until",
|
| 353 |
-
aggregation="ter",
|
| 354 |
-
)
|
| 355 |
-
def ter_fn(items): # This is a passthrough function
|
| 356 |
-
return items
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
@register_metric(
|
| 360 |
-
metric="acc_all",
|
| 361 |
-
higher_is_better=True,
|
| 362 |
-
output_type="loglikelihood",
|
| 363 |
-
aggregation="mean",
|
| 364 |
-
)
|
| 365 |
-
def acc_all(items):
|
| 366 |
-
# Only count as correct if all answers are labeled correctly for each question
|
| 367 |
-
question_scoring_dict = {}
|
| 368 |
-
preds = list(zip(*items))[0]
|
| 369 |
-
docs = list(zip(*items))[1]
|
| 370 |
-
|
| 371 |
-
for doc, pred in zip(docs, preds):
|
| 372 |
-
paragraph_id = doc["idx"]["paragraph"]
|
| 373 |
-
question_id = doc["idx"]["question"]
|
| 374 |
-
if (paragraph_id, question_id) not in question_scoring_dict:
|
| 375 |
-
question_scoring_dict[(paragraph_id, question_id)] = []
|
| 376 |
-
|
| 377 |
-
gold_label = doc["label"] == 1
|
| 378 |
-
|
| 379 |
-
question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
|
| 380 |
-
acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
|
| 381 |
-
return acc
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
def acc_all_stderr(items):
|
| 385 |
-
# Only count as correct if all answers are labeled correctly for each question
|
| 386 |
-
question_scoring_dict = {}
|
| 387 |
-
preds = list(zip(*items))[0]
|
| 388 |
-
docs = list(zip(*items))[1]
|
| 389 |
-
|
| 390 |
-
for doc, pred in zip(docs, preds):
|
| 391 |
-
question_id = doc["idx"]["question"]
|
| 392 |
-
if question_id not in question_scoring_dict:
|
| 393 |
-
question_scoring_dict[question_id] = []
|
| 394 |
-
|
| 395 |
-
gold_label = doc["label"] == 1
|
| 396 |
-
question_scoring_dict[question_id].append(gold_label == pred)
|
| 397 |
-
|
| 398 |
-
acc = mean_stderr([int(all(x)) for x in question_scoring_dict.values()])
|
| 399 |
-
return acc
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
| 403 |
-
"""Compute max metric between prediction and each ground truth."""
|
| 404 |
-
scores_for_ground_truths = []
|
| 405 |
-
for ground_truth in ground_truths:
|
| 406 |
-
score = metric_fn(prediction, ground_truth)
|
| 407 |
-
scores_for_ground_truths.append(score)
|
| 408 |
-
return max(scores_for_ground_truths)
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
def weighted_mean(items):
|
| 412 |
-
a, b = zip(*items)
|
| 413 |
-
return sum(a) / sum(b)
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
def is_non_str_iterable(obj):
|
| 417 |
-
return isinstance(obj, Iterable) and not isinstance(obj, str)
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
def _sacreformat(refs, preds):
|
| 421 |
-
"""Format refs and preds for sacrebleu corpus calculation. It is very particular"""
|
| 422 |
-
# Sacrebleu expects (List[str], List[List[str])
|
| 423 |
-
# e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
|
| 424 |
-
|
| 425 |
-
# Note [ref1_stream] is the first reference for each pred.
|
| 426 |
-
# So lists are size N and (M, N) for N preds and M possible refs for each pred
|
| 427 |
-
# This is a different order of dimensions that I would expect
|
| 428 |
-
|
| 429 |
-
# We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
|
| 430 |
-
# Must become List[List[str]] with the inner list corresponding to preds
|
| 431 |
-
if not is_non_str_iterable(refs):
|
| 432 |
-
refs = list(refs)
|
| 433 |
-
if not is_non_str_iterable(refs[0]):
|
| 434 |
-
refs = [[ref] for ref in refs]
|
| 435 |
-
refs = list(zip(*refs))
|
| 436 |
-
# Note the number of refs in each ref list much match the number of preds
|
| 437 |
-
|
| 438 |
-
# We expect preds to be List[str] or List[List[str]]. Must become List[str]
|
| 439 |
-
if not is_non_str_iterable(preds):
|
| 440 |
-
preds = list(preds)
|
| 441 |
-
if is_non_str_iterable(preds[0]):
|
| 442 |
-
assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
|
| 443 |
-
preds = [pred[0] for pred in preds]
|
| 444 |
-
|
| 445 |
-
return refs, preds
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
# stderr stuff
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
class _bootstrap_internal:
|
| 452 |
-
def __init__(self, f, n) -> None:
|
| 453 |
-
self.f = f
|
| 454 |
-
self.n = n
|
| 455 |
-
|
| 456 |
-
def __call__(self, v):
|
| 457 |
-
i, xs = v
|
| 458 |
-
rnd = random.Random()
|
| 459 |
-
rnd.seed(i)
|
| 460 |
-
res = []
|
| 461 |
-
for _ in range(self.n):
|
| 462 |
-
res.append(self.f(rnd.choices(xs, k=len(xs))))
|
| 463 |
-
return res
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
def bootstrap_stderr(f, xs, iters):
|
| 467 |
-
import multiprocessing as mp
|
| 468 |
-
|
| 469 |
-
pool = mp.Pool(mp.cpu_count())
|
| 470 |
-
# this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
|
| 471 |
-
# equivalent to stderr calculated without Bessel's correction in the stddev.
|
| 472 |
-
# Unfortunately, I haven't been able to figure out what the right correction is
|
| 473 |
-
# to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
|
| 474 |
-
# that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
|
| 475 |
-
# Thankfully, shouldn't matter because our samples are pretty big usually anyways
|
| 476 |
-
res = []
|
| 477 |
-
chunk_size = min(1000, iters)
|
| 478 |
-
from tqdm import tqdm
|
| 479 |
-
|
| 480 |
-
print("bootstrapping for stddev:", f.__name__)
|
| 481 |
-
for bootstrap in tqdm(
|
| 482 |
-
pool.imap(
|
| 483 |
-
_bootstrap_internal(f, chunk_size),
|
| 484 |
-
[(i, xs) for i in range(iters // chunk_size)],
|
| 485 |
-
),
|
| 486 |
-
total=iters // chunk_size,
|
| 487 |
-
):
|
| 488 |
-
# sample w replacement
|
| 489 |
-
res.extend(bootstrap)
|
| 490 |
-
|
| 491 |
-
pool.close()
|
| 492 |
-
return sample_stddev(res)
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
def stderr_for_metric(metric, bootstrap_iters: int):
|
| 496 |
-
if bootstrap_iters <= 0:
|
| 497 |
-
# return no function (don't compute stderr) if bootstrap iters = 0
|
| 498 |
-
return None
|
| 499 |
-
|
| 500 |
-
bootstrappable = [
|
| 501 |
-
median,
|
| 502 |
-
matthews_corrcoef,
|
| 503 |
-
f1_score,
|
| 504 |
-
perplexity,
|
| 505 |
-
bleu,
|
| 506 |
-
chrf,
|
| 507 |
-
ter,
|
| 508 |
-
nanmean,
|
| 509 |
-
]
|
| 510 |
-
|
| 511 |
-
if metric in bootstrappable:
|
| 512 |
-
return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)
|
| 513 |
-
|
| 514 |
-
stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
|
| 515 |
-
|
| 516 |
-
return stderr.get(metric, None)
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
|
| 520 |
-
# Used to aggregate bootstrapped stderrs across subtasks in a group,
|
| 521 |
-
# when we are weighting by the size of each subtask.
|
| 522 |
-
#
|
| 523 |
-
|
| 524 |
-
assert len(stderrs) == len(sizes)
|
| 525 |
-
|
| 526 |
-
# formula source: https://en.wikipedia.org/wiki/Pooled_variance
|
| 527 |
-
# and: https://stats.stackexchange.com/a/4841331
|
| 528 |
-
# this empirically seems to match running `stderr_for_metric` on all instances
|
| 529 |
-
# from the subtasks concatenated with each other.
|
| 530 |
-
pooled_sample_var = (
|
| 531 |
-
sum([(size - 1) * stderr**2 * size for size, stderr in zip(sizes, stderrs)])
|
| 532 |
-
) / (sum(sizes) - len(sizes))
|
| 533 |
-
|
| 534 |
-
return np.sqrt(pooled_sample_var / sum(sizes))
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
|
| 538 |
-
assert metrics is not None, (
|
| 539 |
-
"Need to pass a list of each subtask's metric for this stderr aggregation"
|
| 540 |
-
)
|
| 541 |
-
assert len(stderrs) == len(sizes) and len(sizes) == len(metrics)
|
| 542 |
-
|
| 543 |
-
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1390 for more documentation.
|
| 544 |
-
# This formula depends on sample means.
|
| 545 |
-
# removed because it seems to give erroneously huge stderrs for groupings of tasks
|
| 546 |
-
# and does not seem to match up with bootstrap-calculated stderrs for groups.
|
| 547 |
-
|
| 548 |
-
### don't use this unless a statistician has told you it's the right thing to do ###
|
| 549 |
-
|
| 550 |
-
# accumulators: we'll aggregate pairwise N - 1 times
|
| 551 |
-
variance = stderrs[0] ** 2
|
| 552 |
-
curr_size = sizes[0]
|
| 553 |
-
curr_score = metrics[0]
|
| 554 |
-
|
| 555 |
-
for stderr, size, score in zip(stderrs[1:], sizes[1:], metrics[1:]):
|
| 556 |
-
curr_score = ((curr_score * curr_size) + (score * size)) / (
|
| 557 |
-
curr_size + size
|
| 558 |
-
) # NOTE: this assumes our aggregation fn is "mean"
|
| 559 |
-
|
| 560 |
-
variance = ((curr_size - 1) * variance + (size - 1) * (stderr**2)) / (
|
| 561 |
-
curr_size + size - 1
|
| 562 |
-
) + curr_size * size / ((curr_size + size) * (curr_size + size - 1)) * (
|
| 563 |
-
curr_score - score
|
| 564 |
-
) ** 2
|
| 565 |
-
|
| 566 |
-
return np.sqrt(variance)
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
|
| 570 |
-
# A helper function that is used to aggregate
|
| 571 |
-
# subtask scores cross-task.
|
| 572 |
-
# TODO: does not hold for non-mean aggregations
|
| 573 |
-
if not weight_by_size:
|
| 574 |
-
sizes = [1] * len(sizes)
|
| 575 |
-
|
| 576 |
-
assert len(metrics) == len(sizes)
|
| 577 |
-
|
| 578 |
-
return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/api/model.py
DELETED
|
@@ -1,493 +0,0 @@
|
|
| 1 |
-
import abc
|
| 2 |
-
import hashlib
|
| 3 |
-
import json
|
| 4 |
-
import logging
|
| 5 |
-
import os
|
| 6 |
-
from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
|
| 7 |
-
|
| 8 |
-
import transformers
|
| 9 |
-
from sqlitedict import SqliteDict
|
| 10 |
-
from tqdm import tqdm
|
| 11 |
-
|
| 12 |
-
from lm_eval import utils
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
eval_logger = logging.getLogger(__name__)
|
| 16 |
-
|
| 17 |
-
T = TypeVar("T", bound="LM")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
class LM(abc.ABC):
|
| 21 |
-
def __init__(self) -> None:
|
| 22 |
-
"""Defines the interface that should be implemented by all LM subclasses.
|
| 23 |
-
LMs are assumed to take text (strings) as input and yield strings as output
|
| 24 |
-
(inputs/outputs should be tokenization-agnostic.)
|
| 25 |
-
|
| 26 |
-
"""
|
| 27 |
-
# set rank and world size to a single process, by default.
|
| 28 |
-
self._rank = 0
|
| 29 |
-
self._world_size = 1
|
| 30 |
-
self.cache_hook = CacheHook(None)
|
| 31 |
-
|
| 32 |
-
@abc.abstractmethod
|
| 33 |
-
def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
|
| 34 |
-
"""Compute log-likelihood of generating a continuation from a context.
|
| 35 |
-
Downstream tasks should attempt to use loglikelihood instead of other
|
| 36 |
-
LM calls whenever possible.
|
| 37 |
-
|
| 38 |
-
:param requests: list[Instance]
|
| 39 |
-
A list of Instance objects, with property `args` which returns a tuple (context, continuation).
|
| 40 |
-
`context: str`
|
| 41 |
-
Context string. Implementations of LM must be able to handle an
|
| 42 |
-
empty context string.
|
| 43 |
-
`continuation: str`
|
| 44 |
-
The continuation over which log likelihood will be calculated. If
|
| 45 |
-
there is a word boundary, the space should be in the continuation.
|
| 46 |
-
For example, context="hello" continuation=" world" is correct.
|
| 47 |
-
|
| 48 |
-
:return: list[tuple[float, bool]]
|
| 49 |
-
A list of pairs (logprob, isgreedy)
|
| 50 |
-
`logprob: float`
|
| 51 |
-
The log probability of `continuation`.
|
| 52 |
-
`isgreedy`:
|
| 53 |
-
Whether `continuation` would be generated by greedy sampling from `context`.
|
| 54 |
-
"""
|
| 55 |
-
pass
|
| 56 |
-
|
| 57 |
-
@abc.abstractmethod
|
| 58 |
-
def loglikelihood_rolling(self, requests) -> List[float]:
|
| 59 |
-
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
| 60 |
-
- We will use the full max context length of the model.
|
| 61 |
-
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
|
| 62 |
-
the max context length.
|
| 63 |
-
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
|
| 64 |
-
which may simply concatenate multiple documents together.
|
| 65 |
-
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
|
| 66 |
-
multiple chunks, the last input will still a full-sized context.
|
| 67 |
-
Example:
|
| 68 |
-
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
|
| 69 |
-
Prefix: BOS/EOS
|
| 70 |
-
Max context length: 4
|
| 71 |
-
Resulting input/prediction pairs:
|
| 72 |
-
|
| 73 |
-
INPUT: BOS 0 1 2
|
| 74 |
-
PRED: 0 1 2 3
|
| 75 |
-
|
| 76 |
-
INPUT: 3 4 5 6
|
| 77 |
-
PRED: 4 5 6 7
|
| 78 |
-
|
| 79 |
-
INPUT: 5 6 7 8
|
| 80 |
-
PRED: 8 9
|
| 81 |
-
|
| 82 |
-
Observe that:
|
| 83 |
-
1. Each token is predicted exactly once
|
| 84 |
-
2. For the last pair, we provide the full context, but only score the last two tokens
|
| 85 |
-
|
| 86 |
-
:param requests: list[Instance]
|
| 87 |
-
A list of Instance objects with property `args` which returns a tuple (context,).
|
| 88 |
-
string: str
|
| 89 |
-
String for which we are computing overall loglikelihood
|
| 90 |
-
:return: list[tuple[float]]
|
| 91 |
-
A list of tuples (logprob,)
|
| 92 |
-
logprob: float
|
| 93 |
-
The log probability of `context` conditioned on the BOS/EOS token.
|
| 94 |
-
Can also be overridden for custom cases by `prefix_token_id`.
|
| 95 |
-
"""
|
| 96 |
-
pass
|
| 97 |
-
|
| 98 |
-
# TODO: Add an optional max length
|
| 99 |
-
@abc.abstractmethod
|
| 100 |
-
def generate_until(self, requests) -> List[str]:
|
| 101 |
-
"""Generate greedily until a stopping sequence
|
| 102 |
-
|
| 103 |
-
:param requests: list[Instance]
|
| 104 |
-
A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
|
| 105 |
-
context: str
|
| 106 |
-
Context string
|
| 107 |
-
gen_kwargs: dict
|
| 108 |
-
A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
|
| 109 |
-
:return: list[str]
|
| 110 |
-
A list of model generated continuations.
|
| 111 |
-
continuation: str
|
| 112 |
-
The generated continuation.
|
| 113 |
-
"""
|
| 114 |
-
pass
|
| 115 |
-
|
| 116 |
-
def apply_chat_template(
|
| 117 |
-
self, chat_history: List[Dict[str, str]], add_generation_prompt=True
|
| 118 |
-
) -> str:
|
| 119 |
-
"""
|
| 120 |
-
Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
|
| 121 |
-
|
| 122 |
-
:param chat_history: list[dict[str, str]]
|
| 123 |
-
A list of dictionaries with keys 'role' and 'content'.
|
| 124 |
-
Values are strings representing the role name and the content of the message, respectively.
|
| 125 |
-
:param add_generation_prompt: bool
|
| 126 |
-
Whether to append an assistant gen prefix (for e.g. <|assistant|>) to the assistant messages in the chat history. False if prefilling an assistant message.
|
| 127 |
-
:return: str
|
| 128 |
-
A string representing the chat history in a format that can be used as input to the LM.
|
| 129 |
-
"""
|
| 130 |
-
raise NotImplementedError(
|
| 131 |
-
"To use this model with chat templates, please implement the 'apply_chat_template' method for your model type."
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
@classmethod
|
| 135 |
-
def create_from_arg_string(
|
| 136 |
-
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
|
| 137 |
-
) -> T:
|
| 138 |
-
"""
|
| 139 |
-
Creates an instance of the LM class using the given argument string and additional config.
|
| 140 |
-
|
| 141 |
-
Parameters:
|
| 142 |
-
- arg_string: A string containing arguments in the format key1=value1,key2=value2.
|
| 143 |
-
- additional_config: Optional dictionary containing additional configuration parameters.
|
| 144 |
-
|
| 145 |
-
Returns:
|
| 146 |
-
- Instance of the LM class.
|
| 147 |
-
"""
|
| 148 |
-
additional_config = {} if additional_config is None else additional_config
|
| 149 |
-
args = utils.simple_parse_args_string(arg_string)
|
| 150 |
-
args2 = {k: v for k, v in additional_config.items() if v is not None}
|
| 151 |
-
return cls(**args, **args2)
|
| 152 |
-
|
| 153 |
-
@classmethod
|
| 154 |
-
def create_from_arg_obj(
|
| 155 |
-
cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
|
| 156 |
-
) -> T:
|
| 157 |
-
"""
|
| 158 |
-
Creates an instance of the LM class using the given arg_obj
|
| 159 |
-
|
| 160 |
-
Parameters:
|
| 161 |
-
- arg_obj: A dict containing arguments in the format key1=value1,key2=value2.
|
| 162 |
-
- additional_config: Optional dictionary containing additional configuration parameters.
|
| 163 |
-
|
| 164 |
-
Returns:
|
| 165 |
-
- Instance of the LM class.
|
| 166 |
-
"""
|
| 167 |
-
|
| 168 |
-
additional_config = {} if additional_config is None else additional_config
|
| 169 |
-
additional_config = {
|
| 170 |
-
k: v for k, v in additional_config.items() if v is not None
|
| 171 |
-
}
|
| 172 |
-
|
| 173 |
-
return cls(**arg_dict, **additional_config)
|
| 174 |
-
|
| 175 |
-
@property
|
| 176 |
-
def rank(self):
|
| 177 |
-
# used in the case of parallelism. Hardcoded to
|
| 178 |
-
# ensure no errors arise using API models which do
|
| 179 |
-
# not support multi-device parallelism nor expect it.
|
| 180 |
-
return self._rank
|
| 181 |
-
|
| 182 |
-
@property
|
| 183 |
-
def world_size(self):
|
| 184 |
-
# used in the case of parallelism. Hardcoded to
|
| 185 |
-
# ensure no errors arise using API models which do
|
| 186 |
-
# not support multi-device parallelism nor expect it.
|
| 187 |
-
return self._world_size
|
| 188 |
-
|
| 189 |
-
@property
|
| 190 |
-
def tokenizer_name(self) -> str:
|
| 191 |
-
"""Must be defined for LM subclasses which implement Chat Templating.
|
| 192 |
-
Should return the name of the tokenizer or chat template used.
|
| 193 |
-
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
|
| 194 |
-
"""
|
| 195 |
-
raise NotImplementedError(
|
| 196 |
-
"To use this model with chat templates, please implement the 'tokenizer_name' property."
|
| 197 |
-
)
|
| 198 |
-
|
| 199 |
-
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
|
| 200 |
-
"""Returns the chat template structure for user/assistant messages if a template is provided.
|
| 201 |
-
This method is intended to be overridden in a subclass to define a specific chat template format.
|
| 202 |
-
For models that do not support chat templates, this method returns None by default.
|
| 203 |
-
"""
|
| 204 |
-
|
| 205 |
-
return ""
|
| 206 |
-
|
| 207 |
-
def set_cache_hook(self, cache_hook) -> None:
|
| 208 |
-
self.cache_hook = cache_hook
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
### SQLite-based caching of LM responses
|
| 212 |
-
def hash_args(attr, args):
|
| 213 |
-
dat = json.dumps([attr] + list(args))
|
| 214 |
-
return hashlib.sha256(dat.encode("utf-8")).hexdigest()
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
class CacheHook:
|
| 218 |
-
def __init__(self, cachinglm) -> None:
|
| 219 |
-
if cachinglm is None:
|
| 220 |
-
self.dbdict = None
|
| 221 |
-
return
|
| 222 |
-
|
| 223 |
-
self.dbdict = cachinglm.dbdict
|
| 224 |
-
|
| 225 |
-
def add_partial(self, attr, req, res) -> None:
|
| 226 |
-
if self.dbdict is None:
|
| 227 |
-
return
|
| 228 |
-
hsh = hash_args(attr, req)
|
| 229 |
-
self.dbdict[hsh] = res
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
class CachingLM:
|
| 233 |
-
def __init__(self, lm, cache_db) -> None:
|
| 234 |
-
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
|
| 235 |
-
|
| 236 |
-
:param lm: LM
|
| 237 |
-
Underlying LM
|
| 238 |
-
:param cache_db: str
|
| 239 |
-
Path to cache db
|
| 240 |
-
"""
|
| 241 |
-
self.lm = lm
|
| 242 |
-
self.cache_db = cache_db
|
| 243 |
-
if os.path.dirname(cache_db):
|
| 244 |
-
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
|
| 245 |
-
self.dbdict = SqliteDict(cache_db, autocommit=True)
|
| 246 |
-
|
| 247 |
-
# add hook to lm
|
| 248 |
-
lm.set_cache_hook(self.get_cache_hook())
|
| 249 |
-
|
| 250 |
-
def __getattr__(self, attr: str):
|
| 251 |
-
lm_attr = getattr(self.lm, attr)
|
| 252 |
-
if attr not in ["loglikelihood", "loglikelihood_rolling", "generate_until"]:
|
| 253 |
-
eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
|
| 254 |
-
return lm_attr
|
| 255 |
-
|
| 256 |
-
def fn(requests):
|
| 257 |
-
res = []
|
| 258 |
-
remaining_reqs = []
|
| 259 |
-
warned = False
|
| 260 |
-
# figure out which ones are cached and which ones are new
|
| 261 |
-
eval_logger.info(
|
| 262 |
-
f"Loading '{attr}' responses from cache '{self.cache_db}' where possible..."
|
| 263 |
-
)
|
| 264 |
-
for req in tqdm(requests, desc="Checking cached requests"):
|
| 265 |
-
hsh = hash_args(attr, req.args)
|
| 266 |
-
if attr == "generate_until" and req.args[1].get("do_sample", False):
|
| 267 |
-
# when we are doing non-greedy generation, don't use the cache
|
| 268 |
-
# (else every "randomly sampled" generation would be identical for repeats > 1).
|
| 269 |
-
if not warned:
|
| 270 |
-
eval_logger.warning(
|
| 271 |
-
f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
|
| 272 |
-
)
|
| 273 |
-
warned = True
|
| 274 |
-
res.append(None)
|
| 275 |
-
remaining_reqs.append(req)
|
| 276 |
-
elif hsh in self.dbdict:
|
| 277 |
-
ob = self.dbdict[hsh]
|
| 278 |
-
|
| 279 |
-
assert ob is not None
|
| 280 |
-
|
| 281 |
-
res.append(ob)
|
| 282 |
-
else:
|
| 283 |
-
res.append(None)
|
| 284 |
-
remaining_reqs.append(req)
|
| 285 |
-
eval_logger.info(
|
| 286 |
-
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
|
| 287 |
-
)
|
| 288 |
-
if remaining_reqs:
|
| 289 |
-
# actually run the LM on the requests that do not have cached results
|
| 290 |
-
rem_res = getattr(self.lm, attr)(remaining_reqs)
|
| 291 |
-
else:
|
| 292 |
-
rem_res = []
|
| 293 |
-
|
| 294 |
-
# stick the new ones back into the list and also cache any of the new ones
|
| 295 |
-
resptr = 0
|
| 296 |
-
for req, r in zip(remaining_reqs, rem_res):
|
| 297 |
-
while res[resptr] is not None:
|
| 298 |
-
resptr += 1
|
| 299 |
-
|
| 300 |
-
res[resptr] = r
|
| 301 |
-
|
| 302 |
-
# caching
|
| 303 |
-
hsh = hash_args(attr, req.args)
|
| 304 |
-
self.dbdict[hsh] = r
|
| 305 |
-
self.dbdict.commit()
|
| 306 |
-
|
| 307 |
-
return res
|
| 308 |
-
|
| 309 |
-
return fn
|
| 310 |
-
|
| 311 |
-
def get_cache_hook(self):
|
| 312 |
-
return CacheHook(self)
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
class TemplateLM(LM):
|
| 316 |
-
"""
|
| 317 |
-
A class acting as intermediary between the LM base class
|
| 318 |
-
and boilerplate often included in other LM subclasses.
|
| 319 |
-
"""
|
| 320 |
-
|
| 321 |
-
tokenizer = None
|
| 322 |
-
|
| 323 |
-
@property
|
| 324 |
-
@abc.abstractmethod
|
| 325 |
-
def eot_token_id(self):
|
| 326 |
-
pass
|
| 327 |
-
|
| 328 |
-
@property
|
| 329 |
-
def prefix_token_id(self):
|
| 330 |
-
# it is used as prefix for loglikelihood
|
| 331 |
-
return self.eot_token_id
|
| 332 |
-
|
| 333 |
-
@abc.abstractmethod
|
| 334 |
-
def tok_encode(self, string: str, **kwargs) -> List[int]:
|
| 335 |
-
"""
|
| 336 |
-
Tokenize a string using the model's tokenizer and return a list of token IDs.
|
| 337 |
-
"""
|
| 338 |
-
pass
|
| 339 |
-
|
| 340 |
-
@abc.abstractmethod
|
| 341 |
-
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
|
| 342 |
-
pass
|
| 343 |
-
|
| 344 |
-
def _encode_pair(
|
| 345 |
-
self, context: str, continuation: str
|
| 346 |
-
) -> Tuple[List[int], List[int]]:
|
| 347 |
-
n_spaces = len(context) - len(context.rstrip())
|
| 348 |
-
if n_spaces > 0:
|
| 349 |
-
continuation = context[-n_spaces:] + continuation
|
| 350 |
-
context = context[:-n_spaces]
|
| 351 |
-
|
| 352 |
-
model_class = getattr(self, "AUTO_MODEL_CLASS", None)
|
| 353 |
-
|
| 354 |
-
if model_class == transformers.AutoModelForSeq2SeqLM:
|
| 355 |
-
context_enc = self.tok_encode(context)
|
| 356 |
-
continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
|
| 357 |
-
else:
|
| 358 |
-
whole_enc = self.tok_encode(context + continuation)
|
| 359 |
-
context_enc = self.tok_encode(context)
|
| 360 |
-
|
| 361 |
-
context_enc_len = len(context_enc)
|
| 362 |
-
continuation_enc = whole_enc[context_enc_len:]
|
| 363 |
-
|
| 364 |
-
return context_enc, continuation_enc
|
| 365 |
-
|
| 366 |
-
def loglikelihood(
|
| 367 |
-
self, requests, disable_tqdm: bool = False
|
| 368 |
-
) -> List[Tuple[float, bool]]:
|
| 369 |
-
new_reqs = []
|
| 370 |
-
for context, continuation in [req.args for req in requests]:
|
| 371 |
-
if context == "":
|
| 372 |
-
# BOS or EOS as context
|
| 373 |
-
context_enc, continuation_enc = (
|
| 374 |
-
[self.prefix_token_id],
|
| 375 |
-
self.tok_encode(continuation),
|
| 376 |
-
)
|
| 377 |
-
else:
|
| 378 |
-
context_enc, continuation_enc = self._encode_pair(context, continuation)
|
| 379 |
-
|
| 380 |
-
new_reqs.append(((context, continuation), context_enc, continuation_enc))
|
| 381 |
-
|
| 382 |
-
return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
|
| 383 |
-
|
| 384 |
-
@abc.abstractmethod
|
| 385 |
-
def loglikelihood_rolling(
|
| 386 |
-
self, requests, disable_tqdm: bool = False
|
| 387 |
-
) -> List[float]:
|
| 388 |
-
pass
|
| 389 |
-
|
| 390 |
-
@abc.abstractmethod
|
| 391 |
-
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
|
| 392 |
-
pass
|
| 393 |
-
|
| 394 |
-
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
|
| 395 |
-
"""
|
| 396 |
-
Set and get the appropriate chat template for the model.
|
| 397 |
-
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
|
| 398 |
-
|
| 399 |
-
The template selection logic is adapted from the Transformers library's `apply_chat_template`
|
| 400 |
-
method in the Tokenizer class. The original implementation can be found at:
|
| 401 |
-
https://github.com/huggingface/transformers/blob/fc35907f95459d7a6c5281dfadd680b6f7b620e3/src/transformers/tokenization_utils_base.py#L1687
|
| 402 |
-
|
| 403 |
-
This method ensures that the right template is chosen based on the following:
|
| 404 |
-
0. If the model has no 'tokenizer' attribute: assumes that there is only a single possible chat template, handled on the model provider side internally. Returns the empty string.
|
| 405 |
-
1. If the model's tokenizer has multiple templates:
|
| 406 |
-
a. Use the specified template if it exists in the dictionary.
|
| 407 |
-
b. Use the default template from the list if no specific template is provided.
|
| 408 |
-
c. Raise an error if no default template exists and no specific template is provided.
|
| 409 |
-
2. If the model's tokenizer has a single template or no template:
|
| 410 |
-
a. Use the tokenizer's chat template if available.
|
| 411 |
-
b. Fall back to the default chat template if no tokenizer chat template exists.
|
| 412 |
-
|
| 413 |
-
Args:
|
| 414 |
-
chat_template (Union[bool, str]): Specifies the chat template to use.
|
| 415 |
-
- If False or None, no template is applied.
|
| 416 |
-
- If True, the default or only available template is used.
|
| 417 |
-
- If a string, the template with the matching name is used.
|
| 418 |
-
|
| 419 |
-
Returns:
|
| 420 |
-
Optional[str]: The selected chat template, or None if no template is applied.
|
| 421 |
-
"""
|
| 422 |
-
if self.tokenizer is None:
|
| 423 |
-
return ""
|
| 424 |
-
|
| 425 |
-
if chat_template is False or chat_template is None:
|
| 426 |
-
eval_logger.warning(
|
| 427 |
-
"model.chat_template was called with the chat_template set to False or None. "
|
| 428 |
-
"Therefore no chat template will be applied. Make sure this is an intended behavior."
|
| 429 |
-
)
|
| 430 |
-
return None
|
| 431 |
-
|
| 432 |
-
# Convert boolean chat_template to None to ensure compatibility with the adapted logic
|
| 433 |
-
if isinstance(chat_template, bool):
|
| 434 |
-
chat_template = None
|
| 435 |
-
using_default_template = False
|
| 436 |
-
|
| 437 |
-
# First, handle the cases when the model has a dict of multiple templates
|
| 438 |
-
try:
|
| 439 |
-
template = (
|
| 440 |
-
self.tokenizer.chat_template or self.tokenizer.default_chat_template
|
| 441 |
-
)
|
| 442 |
-
except AttributeError:
|
| 443 |
-
return None
|
| 444 |
-
|
| 445 |
-
if isinstance(template, dict):
|
| 446 |
-
using_default_dict = self.tokenizer.chat_template is None
|
| 447 |
-
|
| 448 |
-
if chat_template is not None:
|
| 449 |
-
if chat_template in template:
|
| 450 |
-
selected_template = template[chat_template]
|
| 451 |
-
if using_default_dict:
|
| 452 |
-
using_default_template = True
|
| 453 |
-
else:
|
| 454 |
-
raise ValueError(
|
| 455 |
-
f"The specified chat template '{chat_template}' is not available. "
|
| 456 |
-
f"Available template names are {sorted(template.keys())}."
|
| 457 |
-
)
|
| 458 |
-
else:
|
| 459 |
-
# If user didn't pass a chat template, use the default template from the dict
|
| 460 |
-
if "default" in template:
|
| 461 |
-
selected_template = template["default"]
|
| 462 |
-
using_default_template = True
|
| 463 |
-
else:
|
| 464 |
-
raise ValueError(
|
| 465 |
-
"This model has multiple chat templates with no default specified! Please either pass a chat "
|
| 466 |
-
"template or the name of the template you wish to use to the `chat_template` argument. Available "
|
| 467 |
-
f"template names are {sorted(template.keys())}."
|
| 468 |
-
)
|
| 469 |
-
|
| 470 |
-
# Cases when the model has a single template or no template
|
| 471 |
-
else:
|
| 472 |
-
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template
|
| 473 |
-
if isinstance(chat_template, str):
|
| 474 |
-
eval_logger.warning(
|
| 475 |
-
"Chat template name provided, but the tokenizer's chat template is not a dictionary. "
|
| 476 |
-
"Using the tokenizer's chat template or the default template instead."
|
| 477 |
-
)
|
| 478 |
-
if self.tokenizer.chat_template is not None:
|
| 479 |
-
selected_template = self.tokenizer.chat_template
|
| 480 |
-
else:
|
| 481 |
-
selected_template = self.tokenizer.default_chat_template
|
| 482 |
-
using_default_template = True
|
| 483 |
-
|
| 484 |
-
if using_default_template:
|
| 485 |
-
eval_logger.warning(
|
| 486 |
-
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
|
| 487 |
-
"very error-prone, because models are often trained with templates different from the class default! "
|
| 488 |
-
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
|
| 489 |
-
"point any code depending on them will stop working. We recommend setting a valid chat template before "
|
| 490 |
-
"then to ensure that this model continues working without issues."
|
| 491 |
-
)
|
| 492 |
-
|
| 493 |
-
return selected_template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/api/registry.py
DELETED
|
@@ -1,196 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
from typing import Callable, Dict, Union
|
| 3 |
-
|
| 4 |
-
import evaluate as hf_evaluate
|
| 5 |
-
|
| 6 |
-
from lm_eval.api.model import LM
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
eval_logger = logging.getLogger(__name__)
|
| 10 |
-
|
| 11 |
-
MODEL_REGISTRY = {}
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def register_model(*names):
|
| 15 |
-
# either pass a list or a single alias.
|
| 16 |
-
# function receives them as a tuple of strings
|
| 17 |
-
|
| 18 |
-
def decorate(cls):
|
| 19 |
-
for name in names:
|
| 20 |
-
assert issubclass(cls, LM), (
|
| 21 |
-
f"Model '{name}' ({cls.__name__}) must extend LM class"
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
assert name not in MODEL_REGISTRY, (
|
| 25 |
-
f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
MODEL_REGISTRY[name] = cls
|
| 29 |
-
return cls
|
| 30 |
-
|
| 31 |
-
return decorate
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def get_model(model_name):
|
| 35 |
-
try:
|
| 36 |
-
return MODEL_REGISTRY[model_name]
|
| 37 |
-
except KeyError:
|
| 38 |
-
raise ValueError(
|
| 39 |
-
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
TASK_REGISTRY = {}
|
| 44 |
-
GROUP_REGISTRY = {}
|
| 45 |
-
ALL_TASKS = set()
|
| 46 |
-
func2task_index = {}
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def register_task(name):
|
| 50 |
-
def decorate(fn):
|
| 51 |
-
assert name not in TASK_REGISTRY, (
|
| 52 |
-
f"task named '{name}' conflicts with existing registered task!"
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
TASK_REGISTRY[name] = fn
|
| 56 |
-
ALL_TASKS.add(name)
|
| 57 |
-
func2task_index[fn.__name__] = name
|
| 58 |
-
return fn
|
| 59 |
-
|
| 60 |
-
return decorate
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def register_group(name):
|
| 64 |
-
def decorate(fn):
|
| 65 |
-
func_name = func2task_index[fn.__name__]
|
| 66 |
-
if name in GROUP_REGISTRY:
|
| 67 |
-
GROUP_REGISTRY[name].append(func_name)
|
| 68 |
-
else:
|
| 69 |
-
GROUP_REGISTRY[name] = [func_name]
|
| 70 |
-
ALL_TASKS.add(name)
|
| 71 |
-
return fn
|
| 72 |
-
|
| 73 |
-
return decorate
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
OUTPUT_TYPE_REGISTRY = {}
|
| 77 |
-
METRIC_REGISTRY = {}
|
| 78 |
-
METRIC_AGGREGATION_REGISTRY = {}
|
| 79 |
-
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
|
| 80 |
-
HIGHER_IS_BETTER_REGISTRY = {}
|
| 81 |
-
FILTER_REGISTRY = {}
|
| 82 |
-
|
| 83 |
-
DEFAULT_METRIC_REGISTRY = {
|
| 84 |
-
"loglikelihood": [
|
| 85 |
-
"perplexity",
|
| 86 |
-
"acc",
|
| 87 |
-
],
|
| 88 |
-
"loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
|
| 89 |
-
"multiple_choice": ["acc", "acc_norm"],
|
| 90 |
-
"generate_until": ["exact_match"],
|
| 91 |
-
}
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def register_metric(**args):
|
| 95 |
-
# TODO: do we want to enforce a certain interface to registered metrics?
|
| 96 |
-
def decorate(fn):
|
| 97 |
-
assert "metric" in args
|
| 98 |
-
name = args["metric"]
|
| 99 |
-
|
| 100 |
-
for key, registry in [
|
| 101 |
-
("metric", METRIC_REGISTRY),
|
| 102 |
-
("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
|
| 103 |
-
("aggregation", METRIC_AGGREGATION_REGISTRY),
|
| 104 |
-
]:
|
| 105 |
-
if key in args:
|
| 106 |
-
value = args[key]
|
| 107 |
-
assert value not in registry, (
|
| 108 |
-
f"{key} named '{value}' conflicts with existing registered {key}!"
|
| 109 |
-
)
|
| 110 |
-
|
| 111 |
-
if key == "metric":
|
| 112 |
-
registry[name] = fn
|
| 113 |
-
elif key == "aggregation":
|
| 114 |
-
registry[name] = AGGREGATION_REGISTRY[value]
|
| 115 |
-
else:
|
| 116 |
-
registry[name] = value
|
| 117 |
-
|
| 118 |
-
return fn
|
| 119 |
-
|
| 120 |
-
return decorate
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
def get_metric(name: str, hf_evaluate_metric=False) -> Callable:
|
| 124 |
-
if not hf_evaluate_metric:
|
| 125 |
-
if name in METRIC_REGISTRY:
|
| 126 |
-
return METRIC_REGISTRY[name]
|
| 127 |
-
else:
|
| 128 |
-
eval_logger.warning(
|
| 129 |
-
f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
try:
|
| 133 |
-
metric_object = hf_evaluate.load(name)
|
| 134 |
-
return metric_object.compute
|
| 135 |
-
except Exception:
|
| 136 |
-
eval_logger.error(
|
| 137 |
-
f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric",
|
| 138 |
-
)
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
def register_aggregation(name: str):
|
| 142 |
-
def decorate(fn):
|
| 143 |
-
assert name not in AGGREGATION_REGISTRY, (
|
| 144 |
-
f"aggregation named '{name}' conflicts with existing registered aggregation!"
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
AGGREGATION_REGISTRY[name] = fn
|
| 148 |
-
return fn
|
| 149 |
-
|
| 150 |
-
return decorate
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
def get_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
|
| 154 |
-
try:
|
| 155 |
-
return AGGREGATION_REGISTRY[name]
|
| 156 |
-
except KeyError:
|
| 157 |
-
eval_logger.warning(f"{name} not a registered aggregation metric!")
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def get_metric_aggregation(name: str) -> Callable[[], Dict[str, Callable]]:
|
| 161 |
-
try:
|
| 162 |
-
return METRIC_AGGREGATION_REGISTRY[name]
|
| 163 |
-
except KeyError:
|
| 164 |
-
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def is_higher_better(metric_name) -> bool:
|
| 168 |
-
try:
|
| 169 |
-
return HIGHER_IS_BETTER_REGISTRY[metric_name]
|
| 170 |
-
except KeyError:
|
| 171 |
-
eval_logger.warning(
|
| 172 |
-
f"higher_is_better not specified for metric '{metric_name}'!"
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def register_filter(name):
|
| 177 |
-
def decorate(cls):
|
| 178 |
-
if name in FILTER_REGISTRY:
|
| 179 |
-
eval_logger.info(
|
| 180 |
-
f"Registering filter `{name}` that is already in Registry {FILTER_REGISTRY}"
|
| 181 |
-
)
|
| 182 |
-
FILTER_REGISTRY[name] = cls
|
| 183 |
-
return cls
|
| 184 |
-
|
| 185 |
-
return decorate
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
def get_filter(filter_name: Union[str, Callable]) -> Callable:
|
| 189 |
-
try:
|
| 190 |
-
return FILTER_REGISTRY[filter_name]
|
| 191 |
-
except KeyError as e:
|
| 192 |
-
if callable(filter_name):
|
| 193 |
-
return filter_name
|
| 194 |
-
else:
|
| 195 |
-
eval_logger.warning(f"filter `{filter_name}` is not registered!")
|
| 196 |
-
raise e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/api/samplers.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import warnings
|
| 3 |
-
from functools import partial
|
| 4 |
-
from typing import TYPE_CHECKING, Iterable, Optional, Union
|
| 5 |
-
|
| 6 |
-
import datasets
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
if TYPE_CHECKING:
|
| 10 |
-
from random import Random
|
| 11 |
-
|
| 12 |
-
from lm_eval.api.task import ConfigurableTask, Task
|
| 13 |
-
|
| 14 |
-
eval_logger = logging.getLogger("lm-eval")
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class ContextSampler:
|
| 18 |
-
def __init__(
|
| 19 |
-
self,
|
| 20 |
-
docs: list[dict],
|
| 21 |
-
task: Union["Task", "ConfigurableTask"],
|
| 22 |
-
fewshot_indices: Optional[Iterable] = None,
|
| 23 |
-
rnd: Optional["Random"] = None,
|
| 24 |
-
) -> None:
|
| 25 |
-
self.rnd = rnd
|
| 26 |
-
if not self.rnd:
|
| 27 |
-
raise ValueError(
|
| 28 |
-
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
-
self.task = task
|
| 32 |
-
self.config = task._config
|
| 33 |
-
|
| 34 |
-
self.target_delimiter = self.config.target_delimiter
|
| 35 |
-
self.fewshot_delimiter = self.config.fewshot_delimiter
|
| 36 |
-
|
| 37 |
-
if (
|
| 38 |
-
self.config.fewshot_config is not None
|
| 39 |
-
and self.config.fewshot_config.get("doc_to_text", None) is not None
|
| 40 |
-
):
|
| 41 |
-
self.doc_to_text = partial(
|
| 42 |
-
self.task.doc_to_text,
|
| 43 |
-
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
|
| 44 |
-
)
|
| 45 |
-
else:
|
| 46 |
-
self.doc_to_text = self.task.doc_to_text
|
| 47 |
-
|
| 48 |
-
if (
|
| 49 |
-
self.config.fewshot_config is not None
|
| 50 |
-
and self.config.fewshot_config.get("doc_to_target", None) is not None
|
| 51 |
-
):
|
| 52 |
-
self.doc_to_target = partial(
|
| 53 |
-
self.task.doc_to_target,
|
| 54 |
-
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
|
| 55 |
-
)
|
| 56 |
-
else:
|
| 57 |
-
self.doc_to_target = self.task.doc_to_target
|
| 58 |
-
|
| 59 |
-
if (
|
| 60 |
-
self.config.fewshot_config is not None
|
| 61 |
-
and self.config.fewshot_config.get("doc_to_choice", None) is not None
|
| 62 |
-
):
|
| 63 |
-
self.doc_to_choice = partial(
|
| 64 |
-
self.task.doc_to_choice,
|
| 65 |
-
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
|
| 66 |
-
)
|
| 67 |
-
else:
|
| 68 |
-
self.doc_to_choice = self.task.doc_to_choice
|
| 69 |
-
|
| 70 |
-
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
|
| 71 |
-
if fewshot_indices: # subset few-shot docs from
|
| 72 |
-
if not isinstance(self.docs, datasets.Dataset):
|
| 73 |
-
raise ValueError(
|
| 74 |
-
"Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
|
| 75 |
-
)
|
| 76 |
-
self.docs = self.docs.select(fewshot_indices)
|
| 77 |
-
|
| 78 |
-
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
|
| 79 |
-
# draw an extra fewshot sample if using same split as evaluating on
|
| 80 |
-
prefix = gen_prefix + " " if gen_prefix else ""
|
| 81 |
-
n_samples = (
|
| 82 |
-
num_fewshot + 1
|
| 83 |
-
if self.config.fewshot_split == self.config.test_split
|
| 84 |
-
else num_fewshot
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
# draw `n_samples` docs from fewshot_docs
|
| 88 |
-
fewshotex = self.sample(n_samples)
|
| 89 |
-
|
| 90 |
-
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
|
| 91 |
-
# TODO: should we just stop people from using fewshot from same split as evaluating?
|
| 92 |
-
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
|
| 93 |
-
|
| 94 |
-
labeled_examples = ""
|
| 95 |
-
for doc in selected_docs:
|
| 96 |
-
doc_content = self.doc_to_text(doc)
|
| 97 |
-
doc_target = self.doc_to_target(doc)
|
| 98 |
-
if self.config.doc_to_choice is None or isinstance(doc_content, str):
|
| 99 |
-
labeled_examples += doc_content
|
| 100 |
-
else:
|
| 101 |
-
labeled_examples += self.doc_to_choice(doc)[doc_content]
|
| 102 |
-
|
| 103 |
-
if doc_target != "":
|
| 104 |
-
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
|
| 105 |
-
# TODO: add logger warn once here.
|
| 106 |
-
warnings.warn(
|
| 107 |
-
"Both target_delimiter and target start with a space. This may cause issues.",
|
| 108 |
-
Warning,
|
| 109 |
-
stacklevel=2,
|
| 110 |
-
)
|
| 111 |
-
labeled_examples += self.target_delimiter
|
| 112 |
-
labeled_examples += prefix
|
| 113 |
-
labeled_examples += (
|
| 114 |
-
str(doc_target[0])
|
| 115 |
-
if isinstance(doc_target, list)
|
| 116 |
-
else doc_target
|
| 117 |
-
if self.config.doc_to_choice is None or isinstance(doc_target, str)
|
| 118 |
-
else str(self.doc_to_choice(doc)[doc_target])
|
| 119 |
-
)
|
| 120 |
-
labeled_examples += self.fewshot_delimiter
|
| 121 |
-
|
| 122 |
-
return labeled_examples
|
| 123 |
-
|
| 124 |
-
def get_chat_context(
|
| 125 |
-
self,
|
| 126 |
-
doc: dict,
|
| 127 |
-
num_fewshot: int,
|
| 128 |
-
fewshot_as_multiturn: bool = False,
|
| 129 |
-
gen_prefix: Optional[str] = None,
|
| 130 |
-
):
|
| 131 |
-
# TODO: Do we need any other delimiter
|
| 132 |
-
prefix = gen_prefix + " " if gen_prefix else ""
|
| 133 |
-
chat_history = []
|
| 134 |
-
# draw an extra fewshot sample if using same split as evaluating on
|
| 135 |
-
n_samples = (
|
| 136 |
-
num_fewshot + 1
|
| 137 |
-
if self.config.fewshot_split == self.config.test_split
|
| 138 |
-
else num_fewshot
|
| 139 |
-
)
|
| 140 |
-
# draw `n_samples` docs from fewshot_docs
|
| 141 |
-
fewshotex = self.sample(n_samples)
|
| 142 |
-
|
| 143 |
-
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
|
| 144 |
-
# TODO: should we just stop people from using fewshot from same split as evaluating?
|
| 145 |
-
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
|
| 146 |
-
|
| 147 |
-
if fewshot_as_multiturn:
|
| 148 |
-
for doc in selected_docs:
|
| 149 |
-
doc_content = self.doc_to_text(doc)
|
| 150 |
-
doc_target = self.doc_to_target(doc)
|
| 151 |
-
chat_history.append(
|
| 152 |
-
{
|
| 153 |
-
"role": "user",
|
| 154 |
-
"content": doc_content
|
| 155 |
-
if self.config.doc_to_choice is None
|
| 156 |
-
or isinstance(doc_content, str)
|
| 157 |
-
else self.doc_to_choice(doc)[doc_content],
|
| 158 |
-
}
|
| 159 |
-
)
|
| 160 |
-
chat_history.append(
|
| 161 |
-
{
|
| 162 |
-
"role": "assistant",
|
| 163 |
-
"content": prefix + str(doc_target[0])
|
| 164 |
-
if isinstance(doc_target, list)
|
| 165 |
-
else prefix + doc_target
|
| 166 |
-
if self.config.doc_to_choice is None
|
| 167 |
-
or isinstance(doc_target, str)
|
| 168 |
-
else prefix + str(self.doc_to_choice(doc)[doc_target]),
|
| 169 |
-
}
|
| 170 |
-
)
|
| 171 |
-
else:
|
| 172 |
-
# get fewshot context as one user turn
|
| 173 |
-
chat_history.append(
|
| 174 |
-
{
|
| 175 |
-
"role": "user",
|
| 176 |
-
"content": self.get_context(
|
| 177 |
-
doc, num_fewshot, gen_prefix=gen_prefix
|
| 178 |
-
),
|
| 179 |
-
}
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
return chat_history
|
| 183 |
-
|
| 184 |
-
def sample(self, n: int):
|
| 185 |
-
"""
|
| 186 |
-
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
|
| 187 |
-
"""
|
| 188 |
-
|
| 189 |
-
return self.rnd.sample(self.docs, n)
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
class FirstNSampler(ContextSampler):
|
| 193 |
-
def sample(self, n: int) -> None:
|
| 194 |
-
"""
|
| 195 |
-
Draw the first `n` samples in order from the specified split.
|
| 196 |
-
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
|
| 197 |
-
"""
|
| 198 |
-
assert n <= len(self.docs), (
|
| 199 |
-
f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
|
| 200 |
-
)
|
| 201 |
-
return self.docs[:n]
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
class BalancedSampler(ContextSampler):
|
| 205 |
-
def sample(self, n: int) -> None:
|
| 206 |
-
"""
|
| 207 |
-
TODO: this should return approximately class-balanced samples from our fewshot examples.
|
| 208 |
-
TODO: what order should they be in? maybe random?
|
| 209 |
-
"""
|
| 210 |
-
|
| 211 |
-
pass
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
class ManualSampler(ContextSampler):
|
| 215 |
-
def sample(self, n: int) -> None:
|
| 216 |
-
""" """
|
| 217 |
-
pass
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
SAMPLER_REGISTRY = {
|
| 221 |
-
"default": ContextSampler,
|
| 222 |
-
"first_n": FirstNSampler,
|
| 223 |
-
}
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
def get_sampler(name: str):
|
| 227 |
-
try:
|
| 228 |
-
return SAMPLER_REGISTRY[name]
|
| 229 |
-
except KeyError:
|
| 230 |
-
raise ValueError(
|
| 231 |
-
f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
|
| 232 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/api/task.py
DELETED
|
@@ -1,1879 +0,0 @@
|
|
| 1 |
-
import abc
|
| 2 |
-
import ast
|
| 3 |
-
import logging
|
| 4 |
-
import random
|
| 5 |
-
import re
|
| 6 |
-
from collections.abc import Callable
|
| 7 |
-
from copy import deepcopy
|
| 8 |
-
from dataclasses import asdict, dataclass
|
| 9 |
-
from inspect import getsource
|
| 10 |
-
from typing import (
|
| 11 |
-
Any,
|
| 12 |
-
Dict,
|
| 13 |
-
Iterable,
|
| 14 |
-
Iterator,
|
| 15 |
-
List,
|
| 16 |
-
Literal,
|
| 17 |
-
Mapping,
|
| 18 |
-
Optional,
|
| 19 |
-
Tuple,
|
| 20 |
-
Union,
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
import datasets
|
| 24 |
-
import numpy as np
|
| 25 |
-
from tqdm import tqdm
|
| 26 |
-
|
| 27 |
-
from lm_eval import utils
|
| 28 |
-
from lm_eval.api import samplers
|
| 29 |
-
from lm_eval.api.instance import Instance, OutputType
|
| 30 |
-
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
|
| 31 |
-
from lm_eval.api.registry import (
|
| 32 |
-
AGGREGATION_REGISTRY,
|
| 33 |
-
DEFAULT_METRIC_REGISTRY,
|
| 34 |
-
get_aggregation,
|
| 35 |
-
get_metric,
|
| 36 |
-
get_metric_aggregation,
|
| 37 |
-
is_higher_better,
|
| 38 |
-
)
|
| 39 |
-
from lm_eval.caching.cache import load_from_cache, save_to_cache
|
| 40 |
-
from lm_eval.filters import build_filter_ensemble
|
| 41 |
-
from lm_eval.prompts import get_prompt
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
ALL_OUTPUT_TYPES = [
|
| 45 |
-
"loglikelihood",
|
| 46 |
-
"multiple_choice",
|
| 47 |
-
"loglikelihood_rolling",
|
| 48 |
-
"generate_until",
|
| 49 |
-
]
|
| 50 |
-
|
| 51 |
-
eval_logger = logging.getLogger(__name__)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
@dataclass
|
| 55 |
-
class TaskConfig(dict):
|
| 56 |
-
# task naming/registry
|
| 57 |
-
task: Optional[str] = None
|
| 58 |
-
task_alias: Optional[str] = None
|
| 59 |
-
tag: Optional[Union[str, list]] = None
|
| 60 |
-
# HF dataset options.
|
| 61 |
-
# which dataset to use,
|
| 62 |
-
# and what splits for what purpose
|
| 63 |
-
custom_dataset: Optional[Callable] = None
|
| 64 |
-
dataset_path: Optional[str] = None
|
| 65 |
-
dataset_name: Optional[str] = None
|
| 66 |
-
dataset_kwargs: Optional[dict] = None
|
| 67 |
-
training_split: Optional[str] = None
|
| 68 |
-
validation_split: Optional[str] = None
|
| 69 |
-
test_split: Optional[str] = None
|
| 70 |
-
fewshot_split: Optional[str] = (
|
| 71 |
-
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
|
| 72 |
-
)
|
| 73 |
-
# formatting / prompting options.
|
| 74 |
-
# see docs/advanced_task_guide.md for more info
|
| 75 |
-
process_docs: Optional[Callable] = None
|
| 76 |
-
doc_to_text: Optional[Union[Callable, str]] = None
|
| 77 |
-
doc_to_target: Optional[Union[Callable, str]] = None
|
| 78 |
-
doc_to_image: Union[Callable, str] = None
|
| 79 |
-
doc_to_audio: Union[Callable, str] = None
|
| 80 |
-
unsafe_code: bool = False
|
| 81 |
-
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
|
| 82 |
-
process_results: Optional[Union[Callable, str]] = None
|
| 83 |
-
use_prompt: Optional[str] = None
|
| 84 |
-
description: str = ""
|
| 85 |
-
target_delimiter: str = " "
|
| 86 |
-
fewshot_delimiter: str = "\n\n"
|
| 87 |
-
fewshot_config: Optional[dict] = None
|
| 88 |
-
# runtime configuration options
|
| 89 |
-
num_fewshot: Optional[int] = None
|
| 90 |
-
# scoring options
|
| 91 |
-
metric_list: Optional[list] = None
|
| 92 |
-
output_type: OutputType = "generate_until"
|
| 93 |
-
generation_kwargs: Optional[dict] = None
|
| 94 |
-
repeats: int = 1
|
| 95 |
-
filter_list: Optional[Union[str, list]] = None
|
| 96 |
-
should_decontaminate: bool = False
|
| 97 |
-
doc_to_decontamination_query: Optional[str] = None
|
| 98 |
-
gen_prefix: Optional[str] = None
|
| 99 |
-
metadata: Optional[dict] = (
|
| 100 |
-
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
|
| 101 |
-
)
|
| 102 |
-
|
| 103 |
-
def __post_init__(self) -> None:
|
| 104 |
-
if self.generation_kwargs is not None:
|
| 105 |
-
if self.output_type != "generate_until":
|
| 106 |
-
eval_logger.warning(
|
| 107 |
-
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
if "temperature" in self.generation_kwargs:
|
| 111 |
-
self.generation_kwargs["temperature"] = float(
|
| 112 |
-
self.generation_kwargs["temperature"]
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
if "until" not in self.generation_kwargs:
|
| 116 |
-
eval_logger.warning(
|
| 117 |
-
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
|
| 118 |
-
)
|
| 119 |
-
self.generation_kwargs["until"] = [self.fewshot_delimiter]
|
| 120 |
-
else:
|
| 121 |
-
if self.output_type == "generate_until":
|
| 122 |
-
# ensure that we greedily generate in absence of explicit arguments otherwise
|
| 123 |
-
self.generation_kwargs = {
|
| 124 |
-
"until": (
|
| 125 |
-
None
|
| 126 |
-
if self.fewshot_delimiter is None
|
| 127 |
-
else [self.fewshot_delimiter]
|
| 128 |
-
),
|
| 129 |
-
"do_sample": False,
|
| 130 |
-
"temperature": 0,
|
| 131 |
-
}
|
| 132 |
-
eval_logger.warning(
|
| 133 |
-
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
def __getitem__(self, item):
|
| 137 |
-
return getattr(self, item)
|
| 138 |
-
|
| 139 |
-
def __setitem__(self, item, value):
|
| 140 |
-
return setattr(self, item, value)
|
| 141 |
-
|
| 142 |
-
def to_dict(self, keep_callable: bool = False) -> dict:
|
| 143 |
-
"""dumps the current config as a dictionary object, as a printable format.
|
| 144 |
-
null fields will not be printed.
|
| 145 |
-
Used for dumping results alongside full task configuration
|
| 146 |
-
|
| 147 |
-
:return: dict
|
| 148 |
-
A printable dictionary version of the TaskConfig object.
|
| 149 |
-
|
| 150 |
-
# TODO: should any default value in the TaskConfig not be printed?
|
| 151 |
-
"""
|
| 152 |
-
cfg_dict = asdict(self)
|
| 153 |
-
# remove values that are `None`
|
| 154 |
-
for k, v in list(cfg_dict.items()):
|
| 155 |
-
if v is None:
|
| 156 |
-
cfg_dict.pop(k)
|
| 157 |
-
elif k == "metric_list":
|
| 158 |
-
for metric_dict in v:
|
| 159 |
-
for metric_key, metric_value in metric_dict.items():
|
| 160 |
-
if callable(metric_value):
|
| 161 |
-
metric_dict[metric_key] = self.serialize_function(
|
| 162 |
-
metric_value, keep_callable=keep_callable
|
| 163 |
-
)
|
| 164 |
-
cfg_dict[k] = v
|
| 165 |
-
elif callable(v):
|
| 166 |
-
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
|
| 167 |
-
return cfg_dict
|
| 168 |
-
|
| 169 |
-
def serialize_function(
|
| 170 |
-
self, value: Union[Callable, str], keep_callable=False
|
| 171 |
-
) -> Union[Callable, str]:
|
| 172 |
-
"""Serializes a given function or string.
|
| 173 |
-
|
| 174 |
-
If 'keep_callable' is True, the original callable is returned.
|
| 175 |
-
Otherwise, attempts to return the source code of the callable using 'getsource'.
|
| 176 |
-
"""
|
| 177 |
-
if keep_callable:
|
| 178 |
-
return value
|
| 179 |
-
else:
|
| 180 |
-
try:
|
| 181 |
-
return getsource(value)
|
| 182 |
-
except (TypeError, OSError):
|
| 183 |
-
return str(value)
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
class Task(abc.ABC):
|
| 187 |
-
"""A task represents an entire benchmark including its dataset, problems,
|
| 188 |
-
answers, and evaluation methods. See BoolQ for a simple example implementation
|
| 189 |
-
|
| 190 |
-
A `doc` can be any python object which represents one instance of evaluation.
|
| 191 |
-
This is usually a dictionary e.g.
|
| 192 |
-
{"question": ..., "answer": ...} or
|
| 193 |
-
{"question": ..., question, answer)
|
| 194 |
-
"""
|
| 195 |
-
|
| 196 |
-
VERSION: Optional[Union[int, str]] = None
|
| 197 |
-
|
| 198 |
-
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
|
| 199 |
-
# or a path to a custom `datasets` loading script.
|
| 200 |
-
DATASET_PATH: Optional[str] = None
|
| 201 |
-
|
| 202 |
-
# The name of a subset within `DATASET_PATH`.
|
| 203 |
-
DATASET_NAME: Optional[str] = None
|
| 204 |
-
|
| 205 |
-
OUTPUT_TYPE: Optional[OutputType] = None
|
| 206 |
-
|
| 207 |
-
def __init__(
|
| 208 |
-
self,
|
| 209 |
-
data_dir: Optional[str] = None,
|
| 210 |
-
cache_dir: Optional[str] = None,
|
| 211 |
-
download_mode: Optional[datasets.DownloadMode] = None,
|
| 212 |
-
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
|
| 213 |
-
) -> None:
|
| 214 |
-
"""
|
| 215 |
-
:param data_dir: str
|
| 216 |
-
Stores the path to a local folder containing the `Task`'s data files.
|
| 217 |
-
Use this to specify the path to manually downloaded data (usually when
|
| 218 |
-
the dataset is not publicly accessible).
|
| 219 |
-
:param cache_dir: str
|
| 220 |
-
The directory to read/write the `Task` dataset. This follows the
|
| 221 |
-
HuggingFace `datasets` API with the default cache directory located at:
|
| 222 |
-
`~/.cache/huggingface/datasets`
|
| 223 |
-
NOTE: You can change the cache location globally for a given process
|
| 224 |
-
to another directory:
|
| 225 |
-
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
| 226 |
-
:param download_mode: datasets.DownloadMode
|
| 227 |
-
How to treat pre-existing `Task` downloads and data.
|
| 228 |
-
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
| 229 |
-
Reuse download and reuse dataset.
|
| 230 |
-
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
| 231 |
-
Reuse download with fresh dataset.
|
| 232 |
-
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
| 233 |
-
Fresh download and fresh dataset.
|
| 234 |
-
"""
|
| 235 |
-
self.download(data_dir, cache_dir, download_mode)
|
| 236 |
-
self._training_docs: Optional[list] = None
|
| 237 |
-
self._fewshot_docs: Optional[list] = None
|
| 238 |
-
self._instances: Optional[List[Instance]] = None
|
| 239 |
-
|
| 240 |
-
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
|
| 241 |
-
|
| 242 |
-
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
|
| 243 |
-
self.fewshot_rnd: Optional[random.Random] = (
|
| 244 |
-
None # purposely induce errors in case of improper usage
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
-
def download(
|
| 248 |
-
self,
|
| 249 |
-
data_dir: Optional[str] = None,
|
| 250 |
-
cache_dir: Optional[str] = None,
|
| 251 |
-
download_mode=None,
|
| 252 |
-
) -> None:
|
| 253 |
-
"""Downloads and returns the task dataset.
|
| 254 |
-
Override this method to download the dataset from a custom API.
|
| 255 |
-
|
| 256 |
-
:param data_dir: str
|
| 257 |
-
Stores the path to a local folder containing the `Task`'s data files.
|
| 258 |
-
Use this to specify the path to manually downloaded data (usually when
|
| 259 |
-
the dataset is not publicly accessible).
|
| 260 |
-
:param cache_dir: str
|
| 261 |
-
The directory to read/write the `Task` dataset. This follows the
|
| 262 |
-
HuggingFace `datasets` API with the default cache directory located at:
|
| 263 |
-
`~/.cache/huggingface/datasets`
|
| 264 |
-
NOTE: You can change the cache location globally for a given process
|
| 265 |
-
by setting the shell environment variable, `HF_DATASETS_CACHE`,
|
| 266 |
-
to another directory:
|
| 267 |
-
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
| 268 |
-
:param download_mode: datasets.DownloadMode
|
| 269 |
-
How to treat pre-existing `Task` downloads and data.
|
| 270 |
-
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
| 271 |
-
Reuse download and reuse dataset.
|
| 272 |
-
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
| 273 |
-
Reuse download with fresh dataset.
|
| 274 |
-
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
| 275 |
-
Fresh download and fresh dataset.
|
| 276 |
-
"""
|
| 277 |
-
self.dataset = datasets.load_dataset(
|
| 278 |
-
path=self.DATASET_PATH,
|
| 279 |
-
name=self.DATASET_NAME,
|
| 280 |
-
data_dir=data_dir,
|
| 281 |
-
cache_dir=cache_dir,
|
| 282 |
-
download_mode=download_mode,
|
| 283 |
-
)
|
| 284 |
-
|
| 285 |
-
@property
|
| 286 |
-
def config(self) -> TaskConfig:
|
| 287 |
-
"""Returns the TaskConfig associated with this class."""
|
| 288 |
-
return self._config
|
| 289 |
-
|
| 290 |
-
@abc.abstractmethod
|
| 291 |
-
def has_training_docs(self):
|
| 292 |
-
"""Whether the task has a training set"""
|
| 293 |
-
pass
|
| 294 |
-
|
| 295 |
-
@abc.abstractmethod
|
| 296 |
-
def has_validation_docs(self):
|
| 297 |
-
"""Whether the task has a validation set"""
|
| 298 |
-
pass
|
| 299 |
-
|
| 300 |
-
@abc.abstractmethod
|
| 301 |
-
def has_test_docs(self):
|
| 302 |
-
"""Whether the task has a test set"""
|
| 303 |
-
pass
|
| 304 |
-
|
| 305 |
-
def training_docs(self) -> Iterable:
|
| 306 |
-
"""
|
| 307 |
-
:return: Iterable[obj]
|
| 308 |
-
A iterable of any object, that doc_to_text can handle
|
| 309 |
-
"""
|
| 310 |
-
return []
|
| 311 |
-
|
| 312 |
-
def validation_docs(self) -> Iterable:
|
| 313 |
-
"""
|
| 314 |
-
:return: Iterable[obj]
|
| 315 |
-
A iterable of any object, that doc_to_text can handle
|
| 316 |
-
"""
|
| 317 |
-
return []
|
| 318 |
-
|
| 319 |
-
def test_docs(self) -> Iterable:
|
| 320 |
-
"""
|
| 321 |
-
:return: Iterable[obj]
|
| 322 |
-
A iterable of any object, that doc_to_text can handle
|
| 323 |
-
"""
|
| 324 |
-
return []
|
| 325 |
-
|
| 326 |
-
def fewshot_docs(self) -> Iterable:
|
| 327 |
-
"""
|
| 328 |
-
:return: Iterable[obj]
|
| 329 |
-
A iterable of any object, that doc_to_text can handle
|
| 330 |
-
"""
|
| 331 |
-
if self.has_training_docs():
|
| 332 |
-
return self.training_docs()
|
| 333 |
-
elif self.has_validation_docs():
|
| 334 |
-
return self.validation_docs()
|
| 335 |
-
else:
|
| 336 |
-
if self.config.get("num_fewshot", 0) > 0:
|
| 337 |
-
eval_logger.warning(
|
| 338 |
-
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
|
| 339 |
-
", using test_docs as fewshot_docs but this is not recommended."
|
| 340 |
-
)
|
| 341 |
-
return self.test_docs()
|
| 342 |
-
|
| 343 |
-
def _process_doc(self, doc: dict) -> dict:
|
| 344 |
-
"""
|
| 345 |
-
Override this to process (detokenize, strip, replace, etc.) individual
|
| 346 |
-
documents. This can be used in a map over documents of a data split.
|
| 347 |
-
E.g. `map(self._process_doc, self.dataset["validation"])`
|
| 348 |
-
|
| 349 |
-
:return: dict
|
| 350 |
-
The processed version of the specified `doc`.
|
| 351 |
-
"""
|
| 352 |
-
return doc
|
| 353 |
-
|
| 354 |
-
@property
|
| 355 |
-
def instances(self) -> List[Instance]:
|
| 356 |
-
"""After calling `task.build_all_requests()`, tasks
|
| 357 |
-
maintain a list of the dataset instances which will be evaluated.
|
| 358 |
-
"""
|
| 359 |
-
return self._instances
|
| 360 |
-
|
| 361 |
-
def fewshot_examples(self, k, rnd):
|
| 362 |
-
if self._training_docs is None:
|
| 363 |
-
self._training_docs = list(self.training_docs())
|
| 364 |
-
|
| 365 |
-
return rnd.sample(self._training_docs, k)
|
| 366 |
-
|
| 367 |
-
def doc_to_decontamination_query(self, doc):
|
| 368 |
-
raise NotImplementedError(
|
| 369 |
-
"Override doc_to_decontamination_query with document specific decontamination query."
|
| 370 |
-
)
|
| 371 |
-
|
| 372 |
-
@abc.abstractmethod
|
| 373 |
-
def doc_to_text(self, doc):
|
| 374 |
-
pass
|
| 375 |
-
|
| 376 |
-
@abc.abstractmethod
|
| 377 |
-
def doc_to_target(self, doc):
|
| 378 |
-
pass
|
| 379 |
-
|
| 380 |
-
# not an abstractmethod because not every language-only task has to implement this
|
| 381 |
-
def doc_to_image(self, doc):
|
| 382 |
-
raise NotImplementedError
|
| 383 |
-
|
| 384 |
-
def doc_to_audio(self, doc):
|
| 385 |
-
raise NotImplementedError
|
| 386 |
-
|
| 387 |
-
def doc_to_prefix(self, doc):
|
| 388 |
-
return ""
|
| 389 |
-
|
| 390 |
-
def build_all_requests(
|
| 391 |
-
self,
|
| 392 |
-
*,
|
| 393 |
-
limit: Union[int, None] = None,
|
| 394 |
-
samples: Optional[List[int]] = None,
|
| 395 |
-
rank: int = 0,
|
| 396 |
-
world_size: int = 1,
|
| 397 |
-
cache_requests: bool = False,
|
| 398 |
-
rewrite_requests_cache: bool = False,
|
| 399 |
-
system_instruction: Optional[str] = None,
|
| 400 |
-
apply_chat_template: bool = False,
|
| 401 |
-
fewshot_as_multiturn: bool = False,
|
| 402 |
-
chat_template: Optional[Callable] = None,
|
| 403 |
-
tokenizer_name: str = "",
|
| 404 |
-
) -> None:
|
| 405 |
-
"""Build a set of Instances for a task, and store them in task.instances"""
|
| 406 |
-
|
| 407 |
-
# used with caching
|
| 408 |
-
og_limit = limit
|
| 409 |
-
|
| 410 |
-
cache_key = f"requests-{self._config.task}-{self.config.num_fewshot}shot-rank{rank}-world_size{world_size}"
|
| 411 |
-
cache_key += "-chat_template" if apply_chat_template else ""
|
| 412 |
-
cache_key += "-fewshot_as_multiturn" if fewshot_as_multiturn else ""
|
| 413 |
-
cache_key += (
|
| 414 |
-
f"-system_prompt_hash{utils.hash_string(system_instruction)}"
|
| 415 |
-
if system_instruction is not None
|
| 416 |
-
else ""
|
| 417 |
-
)
|
| 418 |
-
cache_key += f"-tokenizer{tokenizer_name}"
|
| 419 |
-
|
| 420 |
-
cached_instances = load_from_cache(file_name=cache_key, cache=cache_requests)
|
| 421 |
-
|
| 422 |
-
if cache_requests and cached_instances and not rewrite_requests_cache:
|
| 423 |
-
cached_instances = cached_instances[:limit]
|
| 424 |
-
|
| 425 |
-
flattened_instances = [
|
| 426 |
-
instance
|
| 427 |
-
for instance_group in cached_instances
|
| 428 |
-
for instance in instance_group
|
| 429 |
-
]
|
| 430 |
-
|
| 431 |
-
self._instances = flattened_instances
|
| 432 |
-
return
|
| 433 |
-
|
| 434 |
-
eval_logger.info(f"Building contexts for {self.config.task} on rank {rank}...")
|
| 435 |
-
|
| 436 |
-
instances = []
|
| 437 |
-
|
| 438 |
-
# process all documents when caching is specified for simplicity
|
| 439 |
-
if (
|
| 440 |
-
cache_requests
|
| 441 |
-
and (not cached_instances or rewrite_requests_cache)
|
| 442 |
-
and limit is not None
|
| 443 |
-
):
|
| 444 |
-
limit = None
|
| 445 |
-
|
| 446 |
-
doc_id_docs = list(
|
| 447 |
-
self.doc_iterator(
|
| 448 |
-
rank=rank, limit=limit, samples=samples, world_size=world_size
|
| 449 |
-
)
|
| 450 |
-
)
|
| 451 |
-
|
| 452 |
-
num_docs = len(doc_id_docs)
|
| 453 |
-
|
| 454 |
-
for doc_id, doc in tqdm(
|
| 455 |
-
doc_id_docs,
|
| 456 |
-
total=num_docs,
|
| 457 |
-
):
|
| 458 |
-
# sample fewshot context #TODO: need to offset doc_id by rank now!
|
| 459 |
-
fewshot_ctx = self.fewshot_context(
|
| 460 |
-
doc,
|
| 461 |
-
0 if self.config.num_fewshot is None else self.config.num_fewshot,
|
| 462 |
-
system_instruction,
|
| 463 |
-
apply_chat_template,
|
| 464 |
-
fewshot_as_multiturn,
|
| 465 |
-
chat_template,
|
| 466 |
-
gen_prefix=self.doc_to_prefix(doc),
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
|
| 470 |
-
inst = self.construct_requests(
|
| 471 |
-
doc=doc,
|
| 472 |
-
ctx=fewshot_ctx,
|
| 473 |
-
metadata=(self.config["task"], doc_id, self.config.repeats),
|
| 474 |
-
apply_chat_template=apply_chat_template,
|
| 475 |
-
chat_template=chat_template,
|
| 476 |
-
)
|
| 477 |
-
|
| 478 |
-
if not isinstance(inst, list):
|
| 479 |
-
inst = [inst]
|
| 480 |
-
|
| 481 |
-
instances.append(inst)
|
| 482 |
-
|
| 483 |
-
# now flatten, this is to allow slicing to work with pickles
|
| 484 |
-
|
| 485 |
-
sliced_instances = instances[:og_limit]
|
| 486 |
-
|
| 487 |
-
flattened_instances = [
|
| 488 |
-
instance
|
| 489 |
-
for instance_group in sliced_instances
|
| 490 |
-
for instance in instance_group
|
| 491 |
-
]
|
| 492 |
-
|
| 493 |
-
self._instances = flattened_instances
|
| 494 |
-
|
| 495 |
-
if len(self._instances) == 0:
|
| 496 |
-
raise ValueError("task.build_requests() did not find any docs!")
|
| 497 |
-
|
| 498 |
-
if cache_requests and (not cached_instances or rewrite_requests_cache):
|
| 499 |
-
save_to_cache(file_name=cache_key, obj=instances)
|
| 500 |
-
|
| 501 |
-
@abc.abstractmethod
|
| 502 |
-
def construct_requests(self, doc, ctx, **kwargs):
|
| 503 |
-
"""Uses RequestFactory to construct Requests and returns an iterable of
|
| 504 |
-
Requests which will be sent to the LM.
|
| 505 |
-
|
| 506 |
-
:param doc:
|
| 507 |
-
The document as returned from training_docs, validation_docs, or test_docs.
|
| 508 |
-
:param ctx: str
|
| 509 |
-
The context string, generated by fewshot_context. This includes the natural
|
| 510 |
-
language description, as well as the few shot examples, and the question
|
| 511 |
-
part of the document for `doc`.
|
| 512 |
-
:param doc_idx: int
|
| 513 |
-
The index of a document within `self.test_docs()` or `self.validation_docs()`,
|
| 514 |
-
whichever is the main split used.
|
| 515 |
-
:param repeats: int
|
| 516 |
-
TODO: update this docstring
|
| 517 |
-
The number of times each instance in a dataset is inferred on. Defaults to 1,
|
| 518 |
-
can be increased for techniques like majority voting.
|
| 519 |
-
"""
|
| 520 |
-
pass
|
| 521 |
-
|
| 522 |
-
@abc.abstractmethod
|
| 523 |
-
def process_results(self, doc, results):
|
| 524 |
-
"""Take a single document and the LM results and evaluates, returning a
|
| 525 |
-
dict where keys are the names of submetrics and values are the values of
|
| 526 |
-
the metric for that one document
|
| 527 |
-
|
| 528 |
-
:param doc:
|
| 529 |
-
The document as returned from training_docs, validation_docs, or test_docs.
|
| 530 |
-
:param results:
|
| 531 |
-
The results of the requests created in construct_requests.
|
| 532 |
-
"""
|
| 533 |
-
pass
|
| 534 |
-
|
| 535 |
-
@abc.abstractmethod
|
| 536 |
-
def aggregation(self):
|
| 537 |
-
"""
|
| 538 |
-
:returns: {str: [metric_score] -> float}
|
| 539 |
-
A dictionary where keys are the names of submetrics and values are
|
| 540 |
-
functions that aggregate a list of metric scores
|
| 541 |
-
"""
|
| 542 |
-
pass
|
| 543 |
-
|
| 544 |
-
@abc.abstractmethod
|
| 545 |
-
def higher_is_better(self):
|
| 546 |
-
"""
|
| 547 |
-
:returns: {str: bool}
|
| 548 |
-
A dictionary where keys are the names of submetrics and values are
|
| 549 |
-
whether a higher value of the submetric is better
|
| 550 |
-
"""
|
| 551 |
-
pass
|
| 552 |
-
|
| 553 |
-
def get_config(self, key: str) -> Any:
|
| 554 |
-
return getattr(self._config, key, None)
|
| 555 |
-
|
| 556 |
-
@classmethod
|
| 557 |
-
def count_bytes(cls, doc):
|
| 558 |
-
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
|
| 559 |
-
return len(doc.encode("utf-8"))
|
| 560 |
-
|
| 561 |
-
@classmethod
|
| 562 |
-
def count_words(cls, doc):
|
| 563 |
-
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
|
| 564 |
-
return len(re.split(r"\s+", doc))
|
| 565 |
-
|
| 566 |
-
@utils.positional_deprecated
|
| 567 |
-
def fewshot_context(self, doc, num_fewshot, rnd=None, description=None, **kwargs):
|
| 568 |
-
"""Returns a fewshot context string that is made up of a prepended description
|
| 569 |
-
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
|
| 570 |
-
|
| 571 |
-
:param doc: str
|
| 572 |
-
The document as returned from training_docs, validation_docs, or test_docs.
|
| 573 |
-
:param num_fewshot: int
|
| 574 |
-
The number of fewshot examples to provide in the returned context string.
|
| 575 |
-
:param rnd: random.Random
|
| 576 |
-
The pseudo-random number generator used to randomly sample examples.
|
| 577 |
-
WARNING: This is currently a required arg although it's optionalized with a default `None`.
|
| 578 |
-
:param description: str
|
| 579 |
-
The task's description that will be prepended to the fewshot examples.
|
| 580 |
-
:returns: str
|
| 581 |
-
The fewshot context.
|
| 582 |
-
"""
|
| 583 |
-
if rnd is None:
|
| 584 |
-
if self.fewshot_rnd is not None:
|
| 585 |
-
rnd = self.fewshot_rnd
|
| 586 |
-
else:
|
| 587 |
-
raise ValueError(
|
| 588 |
-
"A `random.Random` generator argument must be provided to `rnd`"
|
| 589 |
-
)
|
| 590 |
-
|
| 591 |
-
description = description if description else ""
|
| 592 |
-
|
| 593 |
-
if num_fewshot == 0:
|
| 594 |
-
labeled_examples = ""
|
| 595 |
-
else:
|
| 596 |
-
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
|
| 597 |
-
if self.has_training_docs():
|
| 598 |
-
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
|
| 599 |
-
else:
|
| 600 |
-
if self._fewshot_docs is None:
|
| 601 |
-
self._fewshot_docs = list(
|
| 602 |
-
self.validation_docs()
|
| 603 |
-
if self.has_validation_docs()
|
| 604 |
-
else self.test_docs()
|
| 605 |
-
)
|
| 606 |
-
|
| 607 |
-
fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
|
| 608 |
-
|
| 609 |
-
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
|
| 610 |
-
fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
|
| 611 |
-
|
| 612 |
-
labeled_examples = (
|
| 613 |
-
"\n\n".join(
|
| 614 |
-
[
|
| 615 |
-
self.doc_to_text(doc) + self.doc_to_target(doc)
|
| 616 |
-
for doc in fewshotex
|
| 617 |
-
]
|
| 618 |
-
)
|
| 619 |
-
+ "\n\n"
|
| 620 |
-
)
|
| 621 |
-
|
| 622 |
-
example = self.doc_to_text(doc)
|
| 623 |
-
return description + labeled_examples + example
|
| 624 |
-
|
| 625 |
-
def apply_filters(self) -> Optional[List[Instance]]:
|
| 626 |
-
"""Iterates over FilterEnsembles and applies them to instances"""
|
| 627 |
-
if hasattr(self, "_filters"):
|
| 628 |
-
for f in self._filters:
|
| 629 |
-
f.apply(self._instances)
|
| 630 |
-
else:
|
| 631 |
-
eval_logger.warning("No filter defined, passing through instances")
|
| 632 |
-
return self._instances
|
| 633 |
-
|
| 634 |
-
def dump_config(self) -> dict:
|
| 635 |
-
"""Returns the config as a dictionary."""
|
| 636 |
-
# TODO: this should only return the overrides applied to a non-YAML task's configuration.
|
| 637 |
-
# (num_fewshot)
|
| 638 |
-
return self.config.to_dict()
|
| 639 |
-
|
| 640 |
-
def set_config(self, key: str, value: Any, update: bool = False) -> None:
|
| 641 |
-
"""Set or update the configuration for a given key."""
|
| 642 |
-
if key is None:
|
| 643 |
-
raise ValueError("Key must be provided.")
|
| 644 |
-
|
| 645 |
-
if update:
|
| 646 |
-
current_value = getattr(self._config, key, {})
|
| 647 |
-
if not isinstance(current_value, dict):
|
| 648 |
-
raise TypeError(
|
| 649 |
-
f"Expected a dict for key '{key}', got {type(current_value).__name__} instead."
|
| 650 |
-
)
|
| 651 |
-
current_value.update(value)
|
| 652 |
-
else:
|
| 653 |
-
setattr(self._config, key, value)
|
| 654 |
-
|
| 655 |
-
def override_metric(self, metric_name: str) -> None:
|
| 656 |
-
"""
|
| 657 |
-
Override the default metrics used for evaluation with custom metrics.
|
| 658 |
-
|
| 659 |
-
Parameters:
|
| 660 |
-
- metric_name (str): The name of the custom metric to override. Should be registered in api.metrics.
|
| 661 |
-
"""
|
| 662 |
-
(
|
| 663 |
-
self._metric_fn_list,
|
| 664 |
-
self._aggregation_list,
|
| 665 |
-
self._metric_fn_kwargs,
|
| 666 |
-
self._higher_is_better,
|
| 667 |
-
) = ({}, {}, {}, {})
|
| 668 |
-
self._metric_fn_list[metric_name] = get_metric(metric_name)
|
| 669 |
-
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
|
| 670 |
-
self._higher_is_better[metric_name] = is_higher_better(metric_name)
|
| 671 |
-
self._metric_fn_kwargs[metric_name] = {}
|
| 672 |
-
if not isinstance(self, ConfigurableTask):
|
| 673 |
-
self.process_results = lambda x, y: {metric_name: get_metric(metric_name)}
|
| 674 |
-
self.aggregation = lambda: {
|
| 675 |
-
metric_name: get_metric_aggregation(metric_name)
|
| 676 |
-
}
|
| 677 |
-
setattr(self._config, "metric_list", [{"metric": metric_name}])
|
| 678 |
-
setattr(self._config, "process_results", None)
|
| 679 |
-
|
| 680 |
-
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
|
| 681 |
-
self.fewshot_rnd = random.Random(seed)
|
| 682 |
-
if hasattr(self, "sampler"):
|
| 683 |
-
self.sampler.rnd = self.fewshot_rnd
|
| 684 |
-
|
| 685 |
-
@property
|
| 686 |
-
def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
|
| 687 |
-
if self.has_test_docs():
|
| 688 |
-
return self.test_docs()
|
| 689 |
-
elif self.has_validation_docs():
|
| 690 |
-
return self.validation_docs()
|
| 691 |
-
else:
|
| 692 |
-
raise ValueError(
|
| 693 |
-
f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
|
| 694 |
-
)
|
| 695 |
-
|
| 696 |
-
def doc_iterator(
|
| 697 |
-
self,
|
| 698 |
-
*,
|
| 699 |
-
rank: int = 0,
|
| 700 |
-
limit: Union[int, None] = None,
|
| 701 |
-
world_size: int = 1,
|
| 702 |
-
samples: Optional[List[int]] = None,
|
| 703 |
-
) -> Iterator[Tuple[int, Any]]:
|
| 704 |
-
if samples:
|
| 705 |
-
n = len(self.eval_docs)
|
| 706 |
-
assert all([e < n for e in samples]), (
|
| 707 |
-
f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
|
| 708 |
-
)
|
| 709 |
-
eval_logger.info(
|
| 710 |
-
f"{self.config.task}: Evaluating on {len(samples)} examples"
|
| 711 |
-
)
|
| 712 |
-
doc_iterator = utils.create_iterator(
|
| 713 |
-
enumerate(x for i, x in enumerate(self.eval_docs) if i in samples),
|
| 714 |
-
rank=int(rank),
|
| 715 |
-
limit=None, # limit does not matter here since we are selecting samples directly
|
| 716 |
-
world_size=int(world_size),
|
| 717 |
-
)
|
| 718 |
-
else:
|
| 719 |
-
limit = int(limit) if limit else None
|
| 720 |
-
doc_iterator = utils.create_iterator(
|
| 721 |
-
enumerate(self.eval_docs),
|
| 722 |
-
rank=int(rank),
|
| 723 |
-
limit=limit,
|
| 724 |
-
world_size=int(world_size),
|
| 725 |
-
)
|
| 726 |
-
return doc_iterator
|
| 727 |
-
|
| 728 |
-
|
| 729 |
-
class ConfigurableTask(Task):
|
| 730 |
-
VERSION = "Yaml"
|
| 731 |
-
OUTPUT_TYPE = None
|
| 732 |
-
CONFIG = None
|
| 733 |
-
|
| 734 |
-
def __init__(
|
| 735 |
-
self,
|
| 736 |
-
data_dir=None,
|
| 737 |
-
cache_dir=None,
|
| 738 |
-
download_mode=None,
|
| 739 |
-
config: Optional[dict] = None,
|
| 740 |
-
) -> None: # TODO no super() call here
|
| 741 |
-
# Get pre-configured attributes
|
| 742 |
-
self._config = self.CONFIG
|
| 743 |
-
|
| 744 |
-
# Use new configurations if there was no preconfiguration
|
| 745 |
-
if self.config is None:
|
| 746 |
-
self._config = TaskConfig(**config)
|
| 747 |
-
# Overwrite configs
|
| 748 |
-
else:
|
| 749 |
-
if config is not None:
|
| 750 |
-
self._config.__dict__.update(config)
|
| 751 |
-
|
| 752 |
-
if self.config is None:
|
| 753 |
-
raise ValueError(
|
| 754 |
-
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
|
| 755 |
-
)
|
| 756 |
-
|
| 757 |
-
if isinstance(self.config.metadata, dict):
|
| 758 |
-
if "version" in self.config.metadata:
|
| 759 |
-
self.VERSION = self.config.metadata["version"]
|
| 760 |
-
|
| 761 |
-
if self.config.output_type is not None:
|
| 762 |
-
if self.config.output_type not in ALL_OUTPUT_TYPES:
|
| 763 |
-
raise ValueError(
|
| 764 |
-
f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
|
| 765 |
-
)
|
| 766 |
-
self.OUTPUT_TYPE = self.config.output_type
|
| 767 |
-
|
| 768 |
-
if self.config.doc_to_image is not None:
|
| 769 |
-
# mark the task as requiring multimodality.
|
| 770 |
-
self.MULTIMODAL = True
|
| 771 |
-
|
| 772 |
-
if self.config.doc_to_audio:
|
| 773 |
-
# mark the task as requiring multimodality.
|
| 774 |
-
self.MULTIMODAL = True
|
| 775 |
-
|
| 776 |
-
if self.config.unsafe_code is not False:
|
| 777 |
-
self.UNSAFE_CODE = True
|
| 778 |
-
|
| 779 |
-
if self.config.dataset_path is not None:
|
| 780 |
-
self.DATASET_PATH = self.config.dataset_path
|
| 781 |
-
|
| 782 |
-
if self.config.dataset_name is not None:
|
| 783 |
-
self.DATASET_NAME = self.config.dataset_name
|
| 784 |
-
|
| 785 |
-
self._metric_fn_list = {}
|
| 786 |
-
self._metric_fn_kwargs = {}
|
| 787 |
-
self._aggregation_list = {}
|
| 788 |
-
self._higher_is_better = {}
|
| 789 |
-
|
| 790 |
-
if self.config.metric_list is None:
|
| 791 |
-
# TODO: handle this in TaskConfig.__post_init__ ?
|
| 792 |
-
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
|
| 793 |
-
|
| 794 |
-
for metric_name in _metric_list:
|
| 795 |
-
self._metric_fn_list[metric_name] = get_metric(metric_name)
|
| 796 |
-
self._metric_fn_kwargs[metric_name] = {}
|
| 797 |
-
self._aggregation_list[metric_name] = get_metric_aggregation(
|
| 798 |
-
metric_name
|
| 799 |
-
)
|
| 800 |
-
self._higher_is_better[metric_name] = is_higher_better(metric_name)
|
| 801 |
-
else:
|
| 802 |
-
for metric_config in self.config.metric_list:
|
| 803 |
-
if "metric" not in metric_config:
|
| 804 |
-
raise ValueError(
|
| 805 |
-
"'metric' key not provided for an entry in 'metric_list', must be specified!"
|
| 806 |
-
)
|
| 807 |
-
metric_name = metric_config["metric"]
|
| 808 |
-
kwargs = {
|
| 809 |
-
key: metric_config[key]
|
| 810 |
-
for key in metric_config
|
| 811 |
-
if key
|
| 812 |
-
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
|
| 813 |
-
}
|
| 814 |
-
hf_evaluate_metric = (
|
| 815 |
-
"hf_evaluate" in metric_config
|
| 816 |
-
and metric_config["hf_evaluate"] is True
|
| 817 |
-
)
|
| 818 |
-
|
| 819 |
-
if self.config.process_results is not None:
|
| 820 |
-
self._metric_fn_list[metric_name] = None
|
| 821 |
-
self._metric_fn_kwargs[metric_name] = {}
|
| 822 |
-
elif callable(metric_name):
|
| 823 |
-
metric_fn = metric_name.__call__
|
| 824 |
-
metric_name = metric_name.__name__
|
| 825 |
-
self._metric_fn_list[metric_name] = metric_fn
|
| 826 |
-
self._metric_fn_kwargs[metric_name] = kwargs
|
| 827 |
-
else:
|
| 828 |
-
self._metric_fn_list[metric_name] = get_metric(
|
| 829 |
-
metric_name, hf_evaluate_metric
|
| 830 |
-
)
|
| 831 |
-
self._metric_fn_kwargs[metric_name] = kwargs
|
| 832 |
-
|
| 833 |
-
if "aggregation" in metric_config:
|
| 834 |
-
agg_name = metric_config["aggregation"]
|
| 835 |
-
if isinstance(agg_name, str):
|
| 836 |
-
self._aggregation_list[metric_name] = get_aggregation(agg_name)
|
| 837 |
-
elif callable(agg_name): # noqa: E721
|
| 838 |
-
self._aggregation_list[metric_name] = metric_config[
|
| 839 |
-
"aggregation"
|
| 840 |
-
]
|
| 841 |
-
else:
|
| 842 |
-
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
|
| 843 |
-
metric_agg = get_metric_aggregation(metric_name)
|
| 844 |
-
eval_logger.warning(
|
| 845 |
-
f"[Task: {self.config.task}] metric {metric_name} is defined, but aggregation is not. "
|
| 846 |
-
f"using default "
|
| 847 |
-
f"aggregation={INV_AGG_REGISTRY[metric_agg]}"
|
| 848 |
-
)
|
| 849 |
-
self._aggregation_list[metric_name] = metric_agg
|
| 850 |
-
|
| 851 |
-
if "higher_is_better" in metric_config:
|
| 852 |
-
self._higher_is_better[metric_name] = metric_config[
|
| 853 |
-
"higher_is_better"
|
| 854 |
-
]
|
| 855 |
-
else:
|
| 856 |
-
eval_logger.warning(
|
| 857 |
-
f"[Task: {self.config.task}] metric {metric_name} is defined, but higher_is_better is not. "
|
| 858 |
-
f"using default "
|
| 859 |
-
f"higher_is_better={is_higher_better(metric_name)}"
|
| 860 |
-
)
|
| 861 |
-
self._higher_is_better[metric_name] = is_higher_better(metric_name)
|
| 862 |
-
|
| 863 |
-
self.download(self.config.dataset_kwargs)
|
| 864 |
-
self._training_docs = None
|
| 865 |
-
self._fewshot_docs = None
|
| 866 |
-
|
| 867 |
-
if self.config.filter_list is not None:
|
| 868 |
-
self._filters = []
|
| 869 |
-
for filter_config in self.config.filter_list:
|
| 870 |
-
filter_name = filter_config["name"]
|
| 871 |
-
filter_functions = filter_config["filter"]
|
| 872 |
-
components = []
|
| 873 |
-
for function in filter_functions:
|
| 874 |
-
kwargs = {
|
| 875 |
-
key: function[key] for key in function if key != "function"
|
| 876 |
-
}
|
| 877 |
-
components.append([function["function"], kwargs])
|
| 878 |
-
filter_pipeline = build_filter_ensemble(filter_name, components)
|
| 879 |
-
self._filters.append(filter_pipeline)
|
| 880 |
-
else:
|
| 881 |
-
# TODO: handle repeats in a more general way rather than just discarding
|
| 882 |
-
eval_logger.debug(
|
| 883 |
-
"No custom filters defined. Using default 'take_first' filter for handling repeats."
|
| 884 |
-
)
|
| 885 |
-
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
|
| 886 |
-
|
| 887 |
-
if self.config.use_prompt is not None:
|
| 888 |
-
eval_logger.info(f"loading prompt {self.config.use_prompt}")
|
| 889 |
-
self.prompt = get_prompt(
|
| 890 |
-
self.config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
|
| 891 |
-
)
|
| 892 |
-
else:
|
| 893 |
-
self.prompt = None
|
| 894 |
-
|
| 895 |
-
if self.fewshot_docs() is not None:
|
| 896 |
-
self.fewshot_rnd = (
|
| 897 |
-
random.Random()
|
| 898 |
-
) # setting with no seed, to be overridden at a later time
|
| 899 |
-
config_sampler: Union[str, Callable] = (
|
| 900 |
-
self.config.fewshot_config.get("sampler", "default")
|
| 901 |
-
if self.config.fewshot_config
|
| 902 |
-
else "default"
|
| 903 |
-
)
|
| 904 |
-
if isinstance(config_sampler, str):
|
| 905 |
-
self.sampler = samplers.get_sampler(config_sampler)(
|
| 906 |
-
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
|
| 907 |
-
)
|
| 908 |
-
elif callable(config_sampler) and issubclass(
|
| 909 |
-
config_sampler, samplers.ContextSampler
|
| 910 |
-
):
|
| 911 |
-
self.sampler = config_sampler(
|
| 912 |
-
docs=list(self.fewshot_docs()), task=self, rnd=self.fewshot_rnd
|
| 913 |
-
)
|
| 914 |
-
else:
|
| 915 |
-
raise TypeError(
|
| 916 |
-
f"fewshot_config.sampler should be a string or callable of ContextSampler type, "
|
| 917 |
-
f"not {type(config_sampler)}"
|
| 918 |
-
)
|
| 919 |
-
|
| 920 |
-
self.task_docs = self.eval_docs
|
| 921 |
-
|
| 922 |
-
# Test One Doc
|
| 923 |
-
self.features = list(self.task_docs.features.keys())
|
| 924 |
-
self.multiple_input = 0
|
| 925 |
-
self.multiple_target = 0
|
| 926 |
-
test_doc = self.task_docs[0]
|
| 927 |
-
test_text = self.doc_to_text(test_doc)
|
| 928 |
-
test_target = self.doc_to_target(test_doc)
|
| 929 |
-
|
| 930 |
-
if self.config.doc_to_choice is not None:
|
| 931 |
-
test_choice = self.doc_to_choice(test_doc)
|
| 932 |
-
if not isinstance(test_choice, list):
|
| 933 |
-
eval_logger.error("doc_to_choice must return list")
|
| 934 |
-
else:
|
| 935 |
-
num_choice = len(test_choice)
|
| 936 |
-
|
| 937 |
-
if isinstance(test_text, int):
|
| 938 |
-
eval_logger.debug(
|
| 939 |
-
"doc_to_text returned an int. Assuming multiple inputs."
|
| 940 |
-
)
|
| 941 |
-
self.multiple_input = num_choice
|
| 942 |
-
else:
|
| 943 |
-
test_choice = None
|
| 944 |
-
|
| 945 |
-
if isinstance(test_target, list):
|
| 946 |
-
eval_logger.debug(
|
| 947 |
-
"doc_to_target returned a list. Assuming multiple targets."
|
| 948 |
-
)
|
| 949 |
-
self.multiple_target = len(test_target)
|
| 950 |
-
else:
|
| 951 |
-
if (isinstance(test_target, int)) and (test_choice is not None):
|
| 952 |
-
test_target = test_choice[test_target]
|
| 953 |
-
else:
|
| 954 |
-
test_target = str(test_target)
|
| 955 |
-
|
| 956 |
-
if test_choice is not None:
|
| 957 |
-
check_choices = test_choice
|
| 958 |
-
else:
|
| 959 |
-
check_choices = [test_target]
|
| 960 |
-
if self.config.doc_to_choice is not None:
|
| 961 |
-
for choice in check_choices:
|
| 962 |
-
choice_has_whitespace = True if choice[0].isspace() else False
|
| 963 |
-
delimiter_has_whitespace = (
|
| 964 |
-
True
|
| 965 |
-
if self.config.target_delimiter.rstrip()
|
| 966 |
-
!= self.config.target_delimiter
|
| 967 |
-
else False
|
| 968 |
-
)
|
| 969 |
-
|
| 970 |
-
if delimiter_has_whitespace and choice_has_whitespace:
|
| 971 |
-
eval_logger.debug(
|
| 972 |
-
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" have whitespace'
|
| 973 |
-
)
|
| 974 |
-
elif (not delimiter_has_whitespace) and (not choice_has_whitespace):
|
| 975 |
-
eval_logger.debug(
|
| 976 |
-
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
|
| 977 |
-
)
|
| 978 |
-
|
| 979 |
-
def download(
|
| 980 |
-
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
|
| 981 |
-
) -> None:
|
| 982 |
-
if isinstance(self.config.custom_dataset, Callable):
|
| 983 |
-
eval_logger.warning(
|
| 984 |
-
f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
|
| 985 |
-
+ "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
|
| 986 |
-
)
|
| 987 |
-
self.dataset = self.config.custom_dataset(
|
| 988 |
-
**(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
|
| 989 |
-
)
|
| 990 |
-
else:
|
| 991 |
-
self.dataset = datasets.load_dataset(
|
| 992 |
-
path=self.DATASET_PATH,
|
| 993 |
-
name=self.DATASET_NAME,
|
| 994 |
-
**dataset_kwargs if dataset_kwargs is not None else {},
|
| 995 |
-
)
|
| 996 |
-
|
| 997 |
-
def has_training_docs(self) -> bool:
|
| 998 |
-
if self.config.training_split is not None:
|
| 999 |
-
return True
|
| 1000 |
-
else:
|
| 1001 |
-
return False
|
| 1002 |
-
|
| 1003 |
-
def has_validation_docs(self) -> bool:
|
| 1004 |
-
if self.config.validation_split is not None:
|
| 1005 |
-
return True
|
| 1006 |
-
else:
|
| 1007 |
-
return False
|
| 1008 |
-
|
| 1009 |
-
def has_test_docs(self) -> bool:
|
| 1010 |
-
if self.config.test_split is not None:
|
| 1011 |
-
return True
|
| 1012 |
-
else:
|
| 1013 |
-
return False
|
| 1014 |
-
|
| 1015 |
-
def training_docs(self) -> datasets.Dataset:
|
| 1016 |
-
if self.has_training_docs():
|
| 1017 |
-
if self.config.process_docs is not None:
|
| 1018 |
-
return self.config.process_docs(
|
| 1019 |
-
self.dataset[self.config.training_split]
|
| 1020 |
-
)
|
| 1021 |
-
return self.dataset[self.config.training_split]
|
| 1022 |
-
|
| 1023 |
-
def validation_docs(self) -> datasets.Dataset:
|
| 1024 |
-
if self.has_validation_docs():
|
| 1025 |
-
if self.config.process_docs is not None:
|
| 1026 |
-
return self.config.process_docs(
|
| 1027 |
-
self.dataset[self.config.validation_split]
|
| 1028 |
-
)
|
| 1029 |
-
return self.dataset[self.config.validation_split]
|
| 1030 |
-
|
| 1031 |
-
def test_docs(self) -> datasets.Dataset:
|
| 1032 |
-
if self.has_test_docs():
|
| 1033 |
-
if self.config.process_docs is not None:
|
| 1034 |
-
return self.config.process_docs(self.dataset[self.config.test_split])
|
| 1035 |
-
return self.dataset[self.config.test_split]
|
| 1036 |
-
|
| 1037 |
-
def fewshot_docs(self):
|
| 1038 |
-
if self.config.fewshot_split is not None:
|
| 1039 |
-
if self.config.process_docs is not None:
|
| 1040 |
-
return self.config.process_docs(self.dataset[self.config.fewshot_split])
|
| 1041 |
-
return self.dataset[self.config.fewshot_split]
|
| 1042 |
-
elif (
|
| 1043 |
-
self.config.fewshot_config is not None
|
| 1044 |
-
and self.config.fewshot_config.get("samples", None) is not None
|
| 1045 |
-
):
|
| 1046 |
-
if isinstance(self.config.fewshot_config["samples"], list):
|
| 1047 |
-
return self.config.fewshot_config["samples"]
|
| 1048 |
-
elif callable(self.config.fewshot_config["samples"]):
|
| 1049 |
-
return self.config.fewshot_config["samples"]()
|
| 1050 |
-
else:
|
| 1051 |
-
raise Exception(
|
| 1052 |
-
"`fewshot_config['samples']` was incorrectly defined in the configuration. It should be either a list of samples as a dict, or function returning this list."
|
| 1053 |
-
)
|
| 1054 |
-
else:
|
| 1055 |
-
if (self.config.num_fewshot is not None) and (self.config.num_fewshot > 0):
|
| 1056 |
-
eval_logger.warning(
|
| 1057 |
-
f"[Task: {self.config.task}] "
|
| 1058 |
-
"num_fewshot > 0 but fewshot_split is None. "
|
| 1059 |
-
"using preconfigured rule."
|
| 1060 |
-
)
|
| 1061 |
-
return super().fewshot_docs()
|
| 1062 |
-
|
| 1063 |
-
@staticmethod
|
| 1064 |
-
def append_target_question(
|
| 1065 |
-
labeled_examples: List[Dict[str, str]],
|
| 1066 |
-
question: str,
|
| 1067 |
-
fewshot_as_multiturn: bool = False,
|
| 1068 |
-
gen_prefix: Optional[str] = None,
|
| 1069 |
-
) -> None:
|
| 1070 |
-
"""Adds a target question to the labeled examples list.
|
| 1071 |
-
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
|
| 1072 |
-
Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
|
| 1073 |
-
"""
|
| 1074 |
-
if not fewshot_as_multiturn:
|
| 1075 |
-
# if no messages or last message is system, append as new user entry
|
| 1076 |
-
if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
|
| 1077 |
-
labeled_examples.append({"role": "user", "content": question})
|
| 1078 |
-
# if last message is user, append to it to avoid two user messages in a row
|
| 1079 |
-
else:
|
| 1080 |
-
labeled_examples[-1]["content"] += question
|
| 1081 |
-
else:
|
| 1082 |
-
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
|
| 1083 |
-
labeled_examples.append({"role": "user", "content": question})
|
| 1084 |
-
if gen_prefix:
|
| 1085 |
-
labeled_examples.append({"role": "assistant", "content": gen_prefix})
|
| 1086 |
-
|
| 1087 |
-
@utils.positional_deprecated
|
| 1088 |
-
def fewshot_context(
|
| 1089 |
-
self,
|
| 1090 |
-
doc: dict,
|
| 1091 |
-
num_fewshot: int,
|
| 1092 |
-
system_instruction: Optional[str] = None,
|
| 1093 |
-
apply_chat_template: bool = False,
|
| 1094 |
-
fewshot_as_multiturn: bool = False,
|
| 1095 |
-
chat_template: Optional[Callable] = None,
|
| 1096 |
-
gen_prefix: Optional[str] = None,
|
| 1097 |
-
) -> Union[str, List[str]]:
|
| 1098 |
-
"""Returns a fewshot context string that is made up of a prepended description
|
| 1099 |
-
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
|
| 1100 |
-
|
| 1101 |
-
:param doc: str
|
| 1102 |
-
The document as returned from training_docs, validation_docs, or test_docs.
|
| 1103 |
-
:param num_fewshot: int
|
| 1104 |
-
The number of fewshot examples to provide in the returned context string.
|
| 1105 |
-
:param system_instruction: str
|
| 1106 |
-
System instruction to be applied to the prompt.
|
| 1107 |
-
:param apply_chat_template: bool
|
| 1108 |
-
Whether to apply the chat template to the fewshot context.
|
| 1109 |
-
:param fewshot_as_multiturn: bool
|
| 1110 |
-
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 1111 |
-
:param chat_template:
|
| 1112 |
-
callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
|
| 1113 |
-
:param gen_prefix:
|
| 1114 |
-
String to append after the <|assistant|> token.
|
| 1115 |
-
:returns: str
|
| 1116 |
-
The fewshot context.
|
| 1117 |
-
"""
|
| 1118 |
-
if apply_chat_template:
|
| 1119 |
-
labeled_examples = []
|
| 1120 |
-
else:
|
| 1121 |
-
labeled_examples = ""
|
| 1122 |
-
|
| 1123 |
-
# get task description
|
| 1124 |
-
if description := self.config.description:
|
| 1125 |
-
description = utils.apply_template(self.config.description, doc)
|
| 1126 |
-
|
| 1127 |
-
# create system prompt based on the provided system instruction and description
|
| 1128 |
-
if system_instruction is not None and description:
|
| 1129 |
-
system_prompt = (
|
| 1130 |
-
f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
|
| 1131 |
-
)
|
| 1132 |
-
elif system_instruction is not None:
|
| 1133 |
-
system_prompt = system_instruction
|
| 1134 |
-
elif description:
|
| 1135 |
-
system_prompt = description
|
| 1136 |
-
else:
|
| 1137 |
-
system_prompt = ""
|
| 1138 |
-
|
| 1139 |
-
# add system prompt if specified
|
| 1140 |
-
if system_prompt:
|
| 1141 |
-
if apply_chat_template:
|
| 1142 |
-
labeled_examples.append({"role": "system", "content": system_prompt})
|
| 1143 |
-
else:
|
| 1144 |
-
labeled_examples = system_prompt
|
| 1145 |
-
# if few-shot - append examples after the system prompt
|
| 1146 |
-
if num_fewshot > 0:
|
| 1147 |
-
if apply_chat_template:
|
| 1148 |
-
labeled_examples.extend(
|
| 1149 |
-
self.sampler.get_chat_context(
|
| 1150 |
-
doc,
|
| 1151 |
-
num_fewshot,
|
| 1152 |
-
fewshot_as_multiturn,
|
| 1153 |
-
gen_prefix=gen_prefix,
|
| 1154 |
-
)
|
| 1155 |
-
)
|
| 1156 |
-
else:
|
| 1157 |
-
labeled_examples += self.sampler.get_context(
|
| 1158 |
-
doc, num_fewshot, gen_prefix=gen_prefix
|
| 1159 |
-
)
|
| 1160 |
-
|
| 1161 |
-
example = self.doc_to_text(doc)
|
| 1162 |
-
if apply_chat_template:
|
| 1163 |
-
if self.multiple_input:
|
| 1164 |
-
# TODO: append prefill?
|
| 1165 |
-
if not labeled_examples:
|
| 1166 |
-
return ""
|
| 1167 |
-
return chat_template(labeled_examples)
|
| 1168 |
-
if isinstance(example, str):
|
| 1169 |
-
self.append_target_question(
|
| 1170 |
-
labeled_examples,
|
| 1171 |
-
example,
|
| 1172 |
-
fewshot_as_multiturn,
|
| 1173 |
-
gen_prefix=gen_prefix,
|
| 1174 |
-
)
|
| 1175 |
-
# for loglikelihood create a list of questions with appended choices
|
| 1176 |
-
elif isinstance(example, list):
|
| 1177 |
-
labeled_examples_list = []
|
| 1178 |
-
# copy chat history for each example and append the answer
|
| 1179 |
-
for ex in example:
|
| 1180 |
-
chat = deepcopy(labeled_examples)
|
| 1181 |
-
self.append_target_question(
|
| 1182 |
-
chat,
|
| 1183 |
-
ex,
|
| 1184 |
-
fewshot_as_multiturn,
|
| 1185 |
-
gen_prefix=gen_prefix,
|
| 1186 |
-
)
|
| 1187 |
-
# TODO: append prefill?
|
| 1188 |
-
labeled_examples_list.append(
|
| 1189 |
-
chat_template(
|
| 1190 |
-
chat,
|
| 1191 |
-
add_generation_prompt=False if gen_prefix else True,
|
| 1192 |
-
)
|
| 1193 |
-
)
|
| 1194 |
-
return labeled_examples_list
|
| 1195 |
-
# if example is an integer, append the choice or convert to string
|
| 1196 |
-
elif isinstance(example, int):
|
| 1197 |
-
if self.config.doc_to_choice is not None:
|
| 1198 |
-
choices = self.doc_to_choice(doc)
|
| 1199 |
-
self.append_target_question(
|
| 1200 |
-
labeled_examples,
|
| 1201 |
-
choices[example],
|
| 1202 |
-
fewshot_as_multiturn,
|
| 1203 |
-
gen_prefix=gen_prefix,
|
| 1204 |
-
)
|
| 1205 |
-
else:
|
| 1206 |
-
self.append_target_question(
|
| 1207 |
-
labeled_examples,
|
| 1208 |
-
str(example),
|
| 1209 |
-
fewshot_as_multiturn,
|
| 1210 |
-
gen_prefix=gen_prefix,
|
| 1211 |
-
)
|
| 1212 |
-
# return lm.apply_chat_template(labeled_examples)
|
| 1213 |
-
return chat_template(
|
| 1214 |
-
labeled_examples,
|
| 1215 |
-
add_generation_prompt=False if gen_prefix else True,
|
| 1216 |
-
)
|
| 1217 |
-
else:
|
| 1218 |
-
prefix = (
|
| 1219 |
-
self.config.target_delimiter + gen_prefix
|
| 1220 |
-
if gen_prefix is not None
|
| 1221 |
-
else ""
|
| 1222 |
-
)
|
| 1223 |
-
if self.multiple_input:
|
| 1224 |
-
return labeled_examples
|
| 1225 |
-
if isinstance(example, str):
|
| 1226 |
-
return labeled_examples + example + prefix
|
| 1227 |
-
elif isinstance(example, list):
|
| 1228 |
-
return [labeled_examples + ex + prefix for ex in example]
|
| 1229 |
-
elif isinstance(example, int):
|
| 1230 |
-
if self.config.doc_to_choice is not None:
|
| 1231 |
-
choices = self.doc_to_choice(doc)
|
| 1232 |
-
return labeled_examples + choices[example] + prefix
|
| 1233 |
-
else:
|
| 1234 |
-
return labeled_examples + str(example) + prefix
|
| 1235 |
-
|
| 1236 |
-
def apply_filters(self) -> Optional[List[Instance]]:
|
| 1237 |
-
"""Iterates over FilterEnsembles and applies them to instances"""
|
| 1238 |
-
if hasattr(self, "_filters"):
|
| 1239 |
-
for f in self._filters:
|
| 1240 |
-
f.apply(self._instances)
|
| 1241 |
-
else:
|
| 1242 |
-
eval_logger.warning("No filter defined, passing through instances")
|
| 1243 |
-
return self._instances
|
| 1244 |
-
|
| 1245 |
-
def should_decontaminate(self):
|
| 1246 |
-
return self.config.should_decontaminate
|
| 1247 |
-
|
| 1248 |
-
def doc_to_decontamination_query(self, doc: dict):
|
| 1249 |
-
if self.config.should_decontaminate:
|
| 1250 |
-
if self.config.doc_to_decontamination_query is None:
|
| 1251 |
-
return self.doc_to_text(doc)
|
| 1252 |
-
else:
|
| 1253 |
-
doc_to_decontamination_query = self.config.doc_to_decontamination_query
|
| 1254 |
-
if doc_to_decontamination_query in self.features:
|
| 1255 |
-
return doc[doc_to_decontamination_query]
|
| 1256 |
-
elif callable(doc_to_decontamination_query):
|
| 1257 |
-
return doc_to_decontamination_query(doc)
|
| 1258 |
-
else:
|
| 1259 |
-
return ast.literal_eval(
|
| 1260 |
-
utils.apply_template(
|
| 1261 |
-
self.config.doc_to_decontamination_query, doc
|
| 1262 |
-
)
|
| 1263 |
-
)
|
| 1264 |
-
|
| 1265 |
-
def _process_doc(self, doc: dict) -> dict:
|
| 1266 |
-
"""
|
| 1267 |
-
Override this to process (detokenize, strip, replace, etc.) individual
|
| 1268 |
-
documents. This can be used in a map over documents of a data split.
|
| 1269 |
-
E.g. `map(self._process_doc, self.dataset["validation"])`
|
| 1270 |
-
|
| 1271 |
-
:return: dict
|
| 1272 |
-
The processed version of the specified `doc`.
|
| 1273 |
-
"""
|
| 1274 |
-
return doc
|
| 1275 |
-
|
| 1276 |
-
def doc_to_text(self, doc, doc_to_text=None):
|
| 1277 |
-
if self.prompt is not None:
|
| 1278 |
-
doc_to_text = self.prompt
|
| 1279 |
-
elif doc_to_text is not None:
|
| 1280 |
-
doc_to_text = doc_to_text
|
| 1281 |
-
else:
|
| 1282 |
-
doc_to_text = self.config.doc_to_text
|
| 1283 |
-
|
| 1284 |
-
if isinstance(doc_to_text, int):
|
| 1285 |
-
return doc_to_text
|
| 1286 |
-
elif isinstance(doc_to_text, str):
|
| 1287 |
-
if doc_to_text in self.features:
|
| 1288 |
-
# if self.config.doc_to_choice is not None:
|
| 1289 |
-
# return self.doc_to_choice(doc)[doc[doc_to_text]]
|
| 1290 |
-
# else:
|
| 1291 |
-
return doc[doc_to_text]
|
| 1292 |
-
else:
|
| 1293 |
-
text_string = utils.apply_template(doc_to_text, doc)
|
| 1294 |
-
if text_string.isdigit() and self._config.doc_to_choice is not None:
|
| 1295 |
-
return ast.literal_eval(text_string)
|
| 1296 |
-
else:
|
| 1297 |
-
return text_string
|
| 1298 |
-
elif callable(doc_to_text):
|
| 1299 |
-
return doc_to_text(doc)
|
| 1300 |
-
# Used when applying a Promptsource template
|
| 1301 |
-
elif hasattr(doc_to_text, "apply"):
|
| 1302 |
-
applied_prompt = doc_to_text.apply(doc)
|
| 1303 |
-
if len(applied_prompt) == 2:
|
| 1304 |
-
return applied_prompt[0]
|
| 1305 |
-
else:
|
| 1306 |
-
eval_logger.warning("Applied prompt returns empty string")
|
| 1307 |
-
return self.config.fewshot_delimiter
|
| 1308 |
-
else:
|
| 1309 |
-
print(type(doc_to_text))
|
| 1310 |
-
raise TypeError
|
| 1311 |
-
|
| 1312 |
-
def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
|
| 1313 |
-
if self.prompt is not None:
|
| 1314 |
-
doc_to_target = self.prompt
|
| 1315 |
-
elif doc_to_target is not None:
|
| 1316 |
-
doc_to_target = doc_to_target
|
| 1317 |
-
else:
|
| 1318 |
-
doc_to_target = self.config.doc_to_target
|
| 1319 |
-
|
| 1320 |
-
if isinstance(doc_to_target, int):
|
| 1321 |
-
return doc_to_target
|
| 1322 |
-
elif isinstance(doc_to_target, str):
|
| 1323 |
-
if doc_to_target in self.features:
|
| 1324 |
-
# if self.config.doc_to_choice is not None:
|
| 1325 |
-
# return self.doc_to_choice(doc)[doc[doc_to_target]]
|
| 1326 |
-
# else:
|
| 1327 |
-
return doc[doc_to_target]
|
| 1328 |
-
else:
|
| 1329 |
-
target_string = utils.apply_template(doc_to_target, doc)
|
| 1330 |
-
if target_string.isdigit() and self._config.doc_to_choice is not None:
|
| 1331 |
-
return ast.literal_eval(target_string)
|
| 1332 |
-
elif (
|
| 1333 |
-
len(target_string) >= 2
|
| 1334 |
-
and (target_string[0] == "[")
|
| 1335 |
-
and (target_string[-1] == "]")
|
| 1336 |
-
):
|
| 1337 |
-
try:
|
| 1338 |
-
return ast.literal_eval(target_string)
|
| 1339 |
-
except (SyntaxError, ValueError):
|
| 1340 |
-
return target_string
|
| 1341 |
-
else:
|
| 1342 |
-
return target_string
|
| 1343 |
-
elif isinstance(doc_to_target, list):
|
| 1344 |
-
return doc_to_target
|
| 1345 |
-
elif callable(doc_to_target):
|
| 1346 |
-
return doc_to_target(doc)
|
| 1347 |
-
# Used when applying a Promptsource template
|
| 1348 |
-
elif hasattr(doc_to_target, "apply"):
|
| 1349 |
-
applied_prompt = doc_to_target.apply(doc)
|
| 1350 |
-
if len(applied_prompt) == 2:
|
| 1351 |
-
return applied_prompt[1]
|
| 1352 |
-
else:
|
| 1353 |
-
eval_logger.warning("Applied prompt returns empty string")
|
| 1354 |
-
return self.config.fewshot_delimiter
|
| 1355 |
-
else:
|
| 1356 |
-
raise TypeError
|
| 1357 |
-
|
| 1358 |
-
def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
|
| 1359 |
-
if self.prompt is not None:
|
| 1360 |
-
doc_to_choice = self.prompt
|
| 1361 |
-
elif doc_to_choice is not None:
|
| 1362 |
-
doc_to_choice = doc_to_choice
|
| 1363 |
-
elif self.config.doc_to_choice is None:
|
| 1364 |
-
eval_logger.error("doc_to_choice was called but not set in config")
|
| 1365 |
-
else:
|
| 1366 |
-
doc_to_choice = self.config.doc_to_choice
|
| 1367 |
-
|
| 1368 |
-
if isinstance(doc_to_choice, str):
|
| 1369 |
-
if doc_to_choice in self.features:
|
| 1370 |
-
return doc[doc_to_choice]
|
| 1371 |
-
else:
|
| 1372 |
-
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
|
| 1373 |
-
elif isinstance(doc_to_choice, list):
|
| 1374 |
-
return doc_to_choice
|
| 1375 |
-
elif isinstance(doc_to_choice, dict):
|
| 1376 |
-
return list(doc_to_choice.values())
|
| 1377 |
-
elif callable(doc_to_choice):
|
| 1378 |
-
return doc_to_choice(doc)
|
| 1379 |
-
elif hasattr(doc_to_choice, "get_answer_choices_list"):
|
| 1380 |
-
return doc_to_choice.get_answer_choices_list(doc)
|
| 1381 |
-
else:
|
| 1382 |
-
raise TypeError
|
| 1383 |
-
|
| 1384 |
-
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
|
| 1385 |
-
if doc_to_image is not None:
|
| 1386 |
-
doc_to_image = doc_to_image
|
| 1387 |
-
elif self.config.doc_to_image is not None:
|
| 1388 |
-
doc_to_image = self.config.doc_to_image
|
| 1389 |
-
else:
|
| 1390 |
-
return None
|
| 1391 |
-
|
| 1392 |
-
if isinstance(doc_to_image, list):
|
| 1393 |
-
image_feature = [
|
| 1394 |
-
self.doc_to_image(doc, feature) for feature in doc_to_image
|
| 1395 |
-
]
|
| 1396 |
-
return [feature for feature in image_feature if feature is not None]
|
| 1397 |
-
elif isinstance(doc_to_image, str):
|
| 1398 |
-
if doc_to_image in self.features:
|
| 1399 |
-
return doc[doc_to_image]
|
| 1400 |
-
else:
|
| 1401 |
-
return ast.literal_eval(utils.apply_template(doc_to_image, doc))
|
| 1402 |
-
elif callable(doc_to_image):
|
| 1403 |
-
return doc_to_image(doc)
|
| 1404 |
-
else:
|
| 1405 |
-
return None
|
| 1406 |
-
|
| 1407 |
-
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
|
| 1408 |
-
if doc_to_audio is not None:
|
| 1409 |
-
doc_to_audio = doc_to_audio
|
| 1410 |
-
elif self.config.doc_to_audio is not None:
|
| 1411 |
-
doc_to_audio = self.config.doc_to_audio
|
| 1412 |
-
else:
|
| 1413 |
-
return None
|
| 1414 |
-
|
| 1415 |
-
if isinstance(doc_to_audio, list):
|
| 1416 |
-
audio_feature = [
|
| 1417 |
-
self.doc_to_audio(doc, feature) for feature in doc_to_audio
|
| 1418 |
-
]
|
| 1419 |
-
return [feature for feature in audio_feature if feature is not None]
|
| 1420 |
-
elif isinstance(doc_to_audio, str):
|
| 1421 |
-
if doc_to_audio in self.features:
|
| 1422 |
-
return doc[doc_to_audio]
|
| 1423 |
-
else:
|
| 1424 |
-
return ast.literal_eval(utils.apply_template(doc_to_audio, doc))
|
| 1425 |
-
elif callable(doc_to_audio):
|
| 1426 |
-
return doc_to_audio(doc)
|
| 1427 |
-
else:
|
| 1428 |
-
return None
|
| 1429 |
-
|
| 1430 |
-
def doc_to_prefix(self, doc):
|
| 1431 |
-
if (gen_prefix := self.config.gen_prefix) is not None:
|
| 1432 |
-
if gen_prefix in self.features:
|
| 1433 |
-
return doc[gen_prefix]
|
| 1434 |
-
else:
|
| 1435 |
-
return utils.apply_template(gen_prefix, doc)
|
| 1436 |
-
return None
|
| 1437 |
-
|
| 1438 |
-
def construct_requests(
|
| 1439 |
-
self, doc: dict, ctx: str, **kwargs
|
| 1440 |
-
) -> Union[List[Instance], Instance]:
|
| 1441 |
-
apply_chat_template = kwargs.pop("apply_chat_template", False)
|
| 1442 |
-
chat_template: Callable | None = kwargs.pop("chat_template", None)
|
| 1443 |
-
|
| 1444 |
-
aux_arguments = None
|
| 1445 |
-
|
| 1446 |
-
if self.OUTPUT_TYPE == "loglikelihood":
|
| 1447 |
-
arguments = (ctx, self.doc_to_target(doc))
|
| 1448 |
-
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
|
| 1449 |
-
arguments = (self.doc_to_target(doc),)
|
| 1450 |
-
elif self.OUTPUT_TYPE == "multiple_choice":
|
| 1451 |
-
choices = self.doc_to_choice(doc)
|
| 1452 |
-
target_delimiter = self.config.target_delimiter
|
| 1453 |
-
if apply_chat_template:
|
| 1454 |
-
target_delimiter = ""
|
| 1455 |
-
if self.multiple_input:
|
| 1456 |
-
# If there are multiple inputs, choices are placed in the ctx
|
| 1457 |
-
# apply chat_template to choices if apply_chat_template
|
| 1458 |
-
cont = self.doc_to_target(doc)
|
| 1459 |
-
|
| 1460 |
-
arguments = [
|
| 1461 |
-
(
|
| 1462 |
-
ctx
|
| 1463 |
-
+ (
|
| 1464 |
-
chat_template([{"role": "user", "content": choice}])
|
| 1465 |
-
if apply_chat_template
|
| 1466 |
-
else choice
|
| 1467 |
-
),
|
| 1468 |
-
f"{target_delimiter}{cont}",
|
| 1469 |
-
)
|
| 1470 |
-
for choice in choices
|
| 1471 |
-
]
|
| 1472 |
-
else:
|
| 1473 |
-
# Otherwise they are placed in the continuation
|
| 1474 |
-
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
|
| 1475 |
-
|
| 1476 |
-
# TODO: we should raise a warning telling users this will at most ~2x runtime.
|
| 1477 |
-
if "acc_mutual_info" in self._metric_fn_list.keys():
|
| 1478 |
-
# if we are calculating multiple choice accuracy
|
| 1479 |
-
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
|
| 1480 |
-
|
| 1481 |
-
# here mutual info refers to calculating
|
| 1482 |
-
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
|
| 1483 |
-
# in other words normalizing by subtracting the unconditional logprob of each choice.
|
| 1484 |
-
# TODO: should these be strided? will have to modify the processing in process_results if so
|
| 1485 |
-
aux_arguments = [
|
| 1486 |
-
("", f"{target_delimiter}{choice}") for choice in choices
|
| 1487 |
-
]
|
| 1488 |
-
|
| 1489 |
-
arguments.extend(aux_arguments)
|
| 1490 |
-
|
| 1491 |
-
elif self.OUTPUT_TYPE == "generate_until":
|
| 1492 |
-
arguments = (ctx, deepcopy(self.config.generation_kwargs))
|
| 1493 |
-
|
| 1494 |
-
multimodal_arg = {}
|
| 1495 |
-
if (
|
| 1496 |
-
self.config.doc_to_image
|
| 1497 |
-
): # TODO: ensure that non-multimodal tasks aren't getting visual args
|
| 1498 |
-
multimodal_arg = {
|
| 1499 |
-
**multimodal_arg,
|
| 1500 |
-
**{"visual": self.doc_to_image(doc)},
|
| 1501 |
-
}
|
| 1502 |
-
|
| 1503 |
-
if (
|
| 1504 |
-
self.config.doc_to_audio
|
| 1505 |
-
): # TODO: ensure that non-multimodal tasks aren't getting audio args
|
| 1506 |
-
multimodal_arg = {
|
| 1507 |
-
**multimodal_arg,
|
| 1508 |
-
**{"audio": self.doc_to_audio(doc)},
|
| 1509 |
-
}
|
| 1510 |
-
|
| 1511 |
-
if bool(multimodal_arg):
|
| 1512 |
-
if isinstance(arguments, list):
|
| 1513 |
-
arguments = [arg + (multimodal_arg,) for arg in arguments]
|
| 1514 |
-
else:
|
| 1515 |
-
arguments = arguments + (multimodal_arg,)
|
| 1516 |
-
|
| 1517 |
-
if self.OUTPUT_TYPE == "multiple_choice":
|
| 1518 |
-
request_list = [
|
| 1519 |
-
Instance(
|
| 1520 |
-
request_type="loglikelihood",
|
| 1521 |
-
doc=doc,
|
| 1522 |
-
arguments=arg,
|
| 1523 |
-
idx=i,
|
| 1524 |
-
**kwargs,
|
| 1525 |
-
)
|
| 1526 |
-
for i, arg in enumerate(arguments)
|
| 1527 |
-
]
|
| 1528 |
-
|
| 1529 |
-
return request_list
|
| 1530 |
-
|
| 1531 |
-
return Instance(
|
| 1532 |
-
request_type=self.OUTPUT_TYPE,
|
| 1533 |
-
doc=doc,
|
| 1534 |
-
arguments=arguments,
|
| 1535 |
-
idx=0,
|
| 1536 |
-
**kwargs,
|
| 1537 |
-
)
|
| 1538 |
-
|
| 1539 |
-
def process_results(self, doc, results):
|
| 1540 |
-
if callable(self.config.process_results):
|
| 1541 |
-
return self.config.process_results(doc, results)
|
| 1542 |
-
|
| 1543 |
-
result_dict = {}
|
| 1544 |
-
use_metric = list(self._metric_fn_list.keys())
|
| 1545 |
-
if self.OUTPUT_TYPE == "loglikelihood":
|
| 1546 |
-
results = results[0]
|
| 1547 |
-
ll, is_greedy = results
|
| 1548 |
-
return {
|
| 1549 |
-
**({"perplexity": ll} if "perplexity" in use_metric else {}),
|
| 1550 |
-
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
|
| 1551 |
-
}
|
| 1552 |
-
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
|
| 1553 |
-
(loglikelihood,) = results
|
| 1554 |
-
_words = self.count_words(self.doc_to_target(doc))
|
| 1555 |
-
_bytes = self.count_bytes(self.doc_to_target(doc))
|
| 1556 |
-
return {
|
| 1557 |
-
**(
|
| 1558 |
-
{"word_perplexity": (loglikelihood, _words)}
|
| 1559 |
-
if "word_perplexity" in use_metric
|
| 1560 |
-
else {}
|
| 1561 |
-
),
|
| 1562 |
-
**(
|
| 1563 |
-
{"byte_perplexity": (loglikelihood, _bytes)}
|
| 1564 |
-
if "byte_perplexity" in use_metric
|
| 1565 |
-
else {}
|
| 1566 |
-
),
|
| 1567 |
-
**(
|
| 1568 |
-
{"bits_per_byte": (loglikelihood, _bytes)}
|
| 1569 |
-
if "bits_per_byte" in use_metric
|
| 1570 |
-
else {}
|
| 1571 |
-
),
|
| 1572 |
-
}
|
| 1573 |
-
elif self.OUTPUT_TYPE == "multiple_choice":
|
| 1574 |
-
lls, is_greedy = zip(*results)
|
| 1575 |
-
|
| 1576 |
-
# retrieve choices in List[str] form, to compute choice lengths, etc.
|
| 1577 |
-
choices = self.doc_to_choice(doc)
|
| 1578 |
-
completion_len = np.array([float(len(i)) for i in choices])
|
| 1579 |
-
|
| 1580 |
-
if (
|
| 1581 |
-
2 * len(choices) == len(lls)
|
| 1582 |
-
and "acc_mutual_info" in self._metric_fn_list.keys()
|
| 1583 |
-
):
|
| 1584 |
-
# then we are doing mutual info.
|
| 1585 |
-
# this stores the "dryrun" / unconditional answer loglikelihoods
|
| 1586 |
-
# as we extend the args list with unconditional ("", continuation) pairs
|
| 1587 |
-
lls_unconditional = lls[len(choices) :]
|
| 1588 |
-
if len(lls_unconditional) != len(choices):
|
| 1589 |
-
raise ValueError
|
| 1590 |
-
# and this stores our "regular" conditional loglikelihoods
|
| 1591 |
-
lls = lls[: len(choices)]
|
| 1592 |
-
|
| 1593 |
-
pred = np.argmax(lls)
|
| 1594 |
-
pred_norm = np.argmax(lls / completion_len)
|
| 1595 |
-
|
| 1596 |
-
if self.multiple_input:
|
| 1597 |
-
gold = self.doc_to_text(doc)
|
| 1598 |
-
else:
|
| 1599 |
-
gold = self.doc_to_target(doc)
|
| 1600 |
-
|
| 1601 |
-
gold_index_error = False
|
| 1602 |
-
if isinstance(gold, list):
|
| 1603 |
-
gold = [i if i < len(choices) else -100 for i in gold]
|
| 1604 |
-
if -100 in gold:
|
| 1605 |
-
gold_index_error = True
|
| 1606 |
-
else:
|
| 1607 |
-
if isinstance(gold, int):
|
| 1608 |
-
gold = gold if gold < len(choices) else -100
|
| 1609 |
-
elif isinstance(gold, str):
|
| 1610 |
-
gold = choices.index(gold) if gold in choices else -100
|
| 1611 |
-
|
| 1612 |
-
if gold == -100:
|
| 1613 |
-
gold_index_error = True
|
| 1614 |
-
|
| 1615 |
-
if gold_index_error:
|
| 1616 |
-
eval_logger.warning(
|
| 1617 |
-
f"Label index was not in within range of available choices,"
|
| 1618 |
-
f"Sample:\n\n{doc}\n\n"
|
| 1619 |
-
)
|
| 1620 |
-
|
| 1621 |
-
if self.multiple_target:
|
| 1622 |
-
acc = 1.0 if pred in gold else 0.0
|
| 1623 |
-
acc_norm = 1.0 if pred_norm in gold else 0.0
|
| 1624 |
-
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
|
| 1625 |
-
else:
|
| 1626 |
-
acc = 1.0 if pred == gold else 0.0
|
| 1627 |
-
acc_norm = 1.0 if pred_norm == gold else 0.0
|
| 1628 |
-
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
|
| 1629 |
-
exact_match = int(is_greedy[gold]) if gold != -100 else 0
|
| 1630 |
-
|
| 1631 |
-
prob_norm = utils.softmax(lls)
|
| 1632 |
-
|
| 1633 |
-
# TODO use keyword arguments to the metric?
|
| 1634 |
-
# gold, pred, norm stuff, the original lls,
|
| 1635 |
-
result_dict = {
|
| 1636 |
-
**({"acc": acc} if "acc" in use_metric else {}),
|
| 1637 |
-
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
|
| 1638 |
-
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
|
| 1639 |
-
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
|
| 1640 |
-
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
|
| 1641 |
-
**(
|
| 1642 |
-
{"brier_score": (gold, prob_norm)}
|
| 1643 |
-
if "brier_score" in use_metric
|
| 1644 |
-
else {}
|
| 1645 |
-
),
|
| 1646 |
-
}
|
| 1647 |
-
|
| 1648 |
-
if "acc_mutual_info" in use_metric:
|
| 1649 |
-
lls_mutual_info = [
|
| 1650 |
-
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
|
| 1651 |
-
]
|
| 1652 |
-
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
|
| 1653 |
-
result_dict["acc_mutual_info"] = acc_mutual_info
|
| 1654 |
-
|
| 1655 |
-
elif self.OUTPUT_TYPE == "generate_until":
|
| 1656 |
-
gold = self.doc_to_target(doc)
|
| 1657 |
-
result = results[0]
|
| 1658 |
-
if self.config.doc_to_choice is not None:
|
| 1659 |
-
# If you set doc_to_choice,
|
| 1660 |
-
# it assumes that doc_to_target returns a number.
|
| 1661 |
-
choices = self.doc_to_choice(doc)
|
| 1662 |
-
gold = choices[gold]
|
| 1663 |
-
# we expect multiple_targets to be a list.
|
| 1664 |
-
elif self.multiple_target:
|
| 1665 |
-
gold = list(gold)
|
| 1666 |
-
# TODO: handle this better
|
| 1667 |
-
elif type(gold) is not type(result) and not (
|
| 1668 |
-
"bypass" in self._metric_fn_list.keys() or isinstance(result, list)
|
| 1669 |
-
):
|
| 1670 |
-
# cast gold to the same type as result
|
| 1671 |
-
gold = type(result)(gold)
|
| 1672 |
-
|
| 1673 |
-
for metric in self._metric_fn_list.keys():
|
| 1674 |
-
if self.multiple_target:
|
| 1675 |
-
# in the case where we have multiple targets,
|
| 1676 |
-
# return true if any are true
|
| 1677 |
-
# TODO: this may break for multipLe_target, non zero-or-1 metrics
|
| 1678 |
-
scores = []
|
| 1679 |
-
if not isinstance(gold, list):
|
| 1680 |
-
# sometimes, a multiple_target dataset has exceptions where one doc has only one string answer
|
| 1681 |
-
# print(gold)
|
| 1682 |
-
gold = [gold]
|
| 1683 |
-
if metric == "exact_match":
|
| 1684 |
-
result = [result for _ in range(len(gold))]
|
| 1685 |
-
scores = self._metric_fn_list[metric](
|
| 1686 |
-
references=gold,
|
| 1687 |
-
predictions=result,
|
| 1688 |
-
**self._metric_fn_kwargs[metric],
|
| 1689 |
-
)[metric]
|
| 1690 |
-
result_score = 1.0 if scores > 0.0 else 0.0
|
| 1691 |
-
else:
|
| 1692 |
-
for gold_option in gold:
|
| 1693 |
-
try:
|
| 1694 |
-
result_score = self._metric_fn_list[metric](
|
| 1695 |
-
references=[gold_option],
|
| 1696 |
-
predictions=[result],
|
| 1697 |
-
**self._metric_fn_kwargs[metric],
|
| 1698 |
-
)
|
| 1699 |
-
except (
|
| 1700 |
-
TypeError
|
| 1701 |
-
): # TODO: this is hacky and I don't want to do it
|
| 1702 |
-
result_score = self._metric_fn_list[metric](
|
| 1703 |
-
[gold_option, result]
|
| 1704 |
-
)
|
| 1705 |
-
if isinstance(result_score, dict):
|
| 1706 |
-
# TODO: this handles the case where HF evaluate returns a dict.
|
| 1707 |
-
result_score = result_score[metric]
|
| 1708 |
-
scores.append(result_score)
|
| 1709 |
-
if any(scores):
|
| 1710 |
-
result_score = 1.0
|
| 1711 |
-
else:
|
| 1712 |
-
result_score = 0.0
|
| 1713 |
-
else:
|
| 1714 |
-
try:
|
| 1715 |
-
result_score = self._metric_fn_list[metric](
|
| 1716 |
-
references=[gold],
|
| 1717 |
-
predictions=[result],
|
| 1718 |
-
**self._metric_fn_kwargs[metric],
|
| 1719 |
-
)
|
| 1720 |
-
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
|
| 1721 |
-
result_score = self._metric_fn_list[metric]([gold, result])
|
| 1722 |
-
if isinstance(result_score, dict):
|
| 1723 |
-
# TODO: this handles the case where HF evaluate returns a dict.
|
| 1724 |
-
# This allows for multiple metrics to be returned from the same function
|
| 1725 |
-
for k, v in result_score.items():
|
| 1726 |
-
result_dict[k] = v
|
| 1727 |
-
else:
|
| 1728 |
-
result_dict[metric] = result_score
|
| 1729 |
-
else:
|
| 1730 |
-
raise ValueError(
|
| 1731 |
-
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
|
| 1732 |
-
"'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
|
| 1733 |
-
)
|
| 1734 |
-
|
| 1735 |
-
return result_dict
|
| 1736 |
-
|
| 1737 |
-
def aggregation(self) -> dict:
|
| 1738 |
-
return self._aggregation_list
|
| 1739 |
-
|
| 1740 |
-
def higher_is_better(self) -> dict:
|
| 1741 |
-
return self._higher_is_better
|
| 1742 |
-
|
| 1743 |
-
def get_config(self, key: str) -> Any:
|
| 1744 |
-
return getattr(self._config, key, None)
|
| 1745 |
-
|
| 1746 |
-
@property
|
| 1747 |
-
def task_name(self) -> Any:
|
| 1748 |
-
return getattr(self.config, "task", None)
|
| 1749 |
-
|
| 1750 |
-
def __repr__(self):
|
| 1751 |
-
return (
|
| 1752 |
-
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
|
| 1753 |
-
f"output_type={self.OUTPUT_TYPE},"
|
| 1754 |
-
f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
|
| 1755 |
-
f"num_samples={len(self.eval_docs)})"
|
| 1756 |
-
)
|
| 1757 |
-
|
| 1758 |
-
|
| 1759 |
-
class MultipleChoiceTask(Task):
|
| 1760 |
-
OUTPUT_TYPE = "loglikelihood"
|
| 1761 |
-
|
| 1762 |
-
def doc_to_target(self, doc: dict) -> str:
|
| 1763 |
-
return " " + doc["choices"][doc["gold"]]
|
| 1764 |
-
|
| 1765 |
-
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
|
| 1766 |
-
# TODO: add mutual info here?
|
| 1767 |
-
return [
|
| 1768 |
-
Instance(
|
| 1769 |
-
request_type="loglikelihood",
|
| 1770 |
-
doc=doc,
|
| 1771 |
-
arguments=(ctx, " {}".format(choice)),
|
| 1772 |
-
idx=i,
|
| 1773 |
-
**kwargs,
|
| 1774 |
-
)
|
| 1775 |
-
for i, choice in enumerate(doc["choices"])
|
| 1776 |
-
]
|
| 1777 |
-
|
| 1778 |
-
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
|
| 1779 |
-
results = [
|
| 1780 |
-
res[0] for res in results
|
| 1781 |
-
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
|
| 1782 |
-
gold = doc["gold"]
|
| 1783 |
-
|
| 1784 |
-
acc = 1.0 if np.argmax(results) == gold else 0.0
|
| 1785 |
-
completion_len = np.array([float(len(i)) for i in doc["choices"]])
|
| 1786 |
-
acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0
|
| 1787 |
-
|
| 1788 |
-
return {
|
| 1789 |
-
"acc": acc,
|
| 1790 |
-
"acc_norm": acc_norm,
|
| 1791 |
-
}
|
| 1792 |
-
|
| 1793 |
-
def higher_is_better(self) -> dict:
|
| 1794 |
-
return {
|
| 1795 |
-
"acc": True,
|
| 1796 |
-
"acc_norm": True,
|
| 1797 |
-
}
|
| 1798 |
-
|
| 1799 |
-
def aggregation(self) -> dict:
|
| 1800 |
-
return {
|
| 1801 |
-
"acc": mean,
|
| 1802 |
-
"acc_norm": mean,
|
| 1803 |
-
}
|
| 1804 |
-
|
| 1805 |
-
|
| 1806 |
-
class PerplexityTask(Task):
|
| 1807 |
-
OUTPUT_TYPE = "loglikelihood_rolling"
|
| 1808 |
-
|
| 1809 |
-
def has_training_docs(self) -> bool:
|
| 1810 |
-
return False
|
| 1811 |
-
|
| 1812 |
-
def fewshot_examples(self, k: int, rnd) -> List:
|
| 1813 |
-
if k != 0:
|
| 1814 |
-
raise ValueError(
|
| 1815 |
-
"The number of fewshot examples must be 0 for perplexity tasks."
|
| 1816 |
-
)
|
| 1817 |
-
return []
|
| 1818 |
-
|
| 1819 |
-
def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
|
| 1820 |
-
if num_fewshot != 0:
|
| 1821 |
-
raise ValueError(
|
| 1822 |
-
"The number of fewshot examples must be 0 for perplexity tasks."
|
| 1823 |
-
)
|
| 1824 |
-
|
| 1825 |
-
return ""
|
| 1826 |
-
|
| 1827 |
-
def higher_is_better(self) -> dict:
|
| 1828 |
-
return {
|
| 1829 |
-
"word_perplexity": False,
|
| 1830 |
-
"byte_perplexity": False,
|
| 1831 |
-
"bits_per_byte": False,
|
| 1832 |
-
}
|
| 1833 |
-
|
| 1834 |
-
def doc_to_decontamination_query(self, doc):
|
| 1835 |
-
return doc
|
| 1836 |
-
|
| 1837 |
-
def doc_to_text(self, doc) -> str:
|
| 1838 |
-
return ""
|
| 1839 |
-
|
| 1840 |
-
def doc_to_target(self, doc):
|
| 1841 |
-
return doc
|
| 1842 |
-
|
| 1843 |
-
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
|
| 1844 |
-
if bool(ctx):
|
| 1845 |
-
raise ValueError
|
| 1846 |
-
|
| 1847 |
-
return Instance(
|
| 1848 |
-
request_type=self.OUTPUT_TYPE,
|
| 1849 |
-
doc=doc,
|
| 1850 |
-
arguments=(self.doc_to_target(doc),),
|
| 1851 |
-
idx=0,
|
| 1852 |
-
**kwargs,
|
| 1853 |
-
)
|
| 1854 |
-
|
| 1855 |
-
def process_results(self, doc: dict, results: Tuple[float]) -> dict:
|
| 1856 |
-
(loglikelihood,) = results
|
| 1857 |
-
words = self.count_words(self.doc_to_target(doc))
|
| 1858 |
-
bytes_ = self.count_bytes(self.doc_to_target(doc))
|
| 1859 |
-
return {
|
| 1860 |
-
"word_perplexity": (loglikelihood, words),
|
| 1861 |
-
"byte_perplexity": (loglikelihood, bytes_),
|
| 1862 |
-
"bits_per_byte": (loglikelihood, bytes_),
|
| 1863 |
-
}
|
| 1864 |
-
|
| 1865 |
-
def aggregation(self) -> dict:
|
| 1866 |
-
return {
|
| 1867 |
-
"word_perplexity": weighted_perplexity,
|
| 1868 |
-
"byte_perplexity": weighted_perplexity,
|
| 1869 |
-
"bits_per_byte": bits_per_byte,
|
| 1870 |
-
}
|
| 1871 |
-
|
| 1872 |
-
@classmethod
|
| 1873 |
-
def count_bytes(cls, doc) -> int:
|
| 1874 |
-
return len(doc.encode("utf-8"))
|
| 1875 |
-
|
| 1876 |
-
@classmethod
|
| 1877 |
-
def count_words(cls, doc) -> int:
|
| 1878 |
-
"""Downstream tasks with custom word boundaries should override this!"""
|
| 1879 |
-
return len(re.split(r"\s+", doc))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/caching/cache.py
DELETED
|
@@ -1,59 +0,0 @@
|
|
| 1 |
-
import hashlib
|
| 2 |
-
import logging
|
| 3 |
-
import os
|
| 4 |
-
|
| 5 |
-
import dill
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
eval_logger = logging.getLogger(__name__)
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
MODULE_DIR = os.path.dirname(os.path.realpath(__file__))
|
| 12 |
-
|
| 13 |
-
OVERRIDE_PATH = os.getenv("LM_HARNESS_CACHE_PATH")
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
PATH = OVERRIDE_PATH if OVERRIDE_PATH else f"{MODULE_DIR}/.cache"
|
| 17 |
-
|
| 18 |
-
# This should be sufficient for uniqueness
|
| 19 |
-
HASH_INPUT = "EleutherAI-lm-evaluation-harness"
|
| 20 |
-
|
| 21 |
-
HASH_PREFIX = hashlib.sha256(HASH_INPUT.encode("utf-8")).hexdigest()
|
| 22 |
-
|
| 23 |
-
FILE_SUFFIX = f".{HASH_PREFIX}.pickle"
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def load_from_cache(file_name: str, cache: bool = False):
|
| 27 |
-
if not cache:
|
| 28 |
-
return
|
| 29 |
-
try:
|
| 30 |
-
path = f"{PATH}/{file_name}{FILE_SUFFIX}"
|
| 31 |
-
|
| 32 |
-
with open(path, "rb") as file:
|
| 33 |
-
cached_task_dict = dill.loads(file.read())
|
| 34 |
-
return cached_task_dict
|
| 35 |
-
|
| 36 |
-
except Exception:
|
| 37 |
-
eval_logger.debug(f"{file_name} is not cached, generating...")
|
| 38 |
-
pass
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def save_to_cache(file_name, obj):
|
| 42 |
-
if not os.path.exists(PATH):
|
| 43 |
-
os.mkdir(PATH)
|
| 44 |
-
|
| 45 |
-
file_path = f"{PATH}/{file_name}{FILE_SUFFIX}"
|
| 46 |
-
|
| 47 |
-
eval_logger.debug(f"Saving {file_path} to cache...")
|
| 48 |
-
with open(file_path, "wb") as file:
|
| 49 |
-
file.write(dill.dumps(obj))
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
# NOTE the "key" param is to allow for flexibility
|
| 53 |
-
def delete_cache(key: str = ""):
|
| 54 |
-
files = os.listdir(PATH)
|
| 55 |
-
|
| 56 |
-
for file in files:
|
| 57 |
-
if file.startswith(key) and file.endswith(FILE_SUFFIX):
|
| 58 |
-
file_path = f"{PATH}/{file}"
|
| 59 |
-
os.unlink(file_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/decontamination/__init__.py
DELETED
|
File without changes
|
lm-evaluation-harness/lm_eval/decontamination/archiver.py
DELETED
|
@@ -1,174 +0,0 @@
|
|
| 1 |
-
import datetime
|
| 2 |
-
import io
|
| 3 |
-
import json
|
| 4 |
-
import mmap
|
| 5 |
-
import os
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import Any
|
| 8 |
-
|
| 9 |
-
import jsonlines
|
| 10 |
-
import tqdm
|
| 11 |
-
import zstandard
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def json_serial(obj: Any) -> str:
|
| 15 |
-
"""JSON serializer for objects not serializable by default json code"""
|
| 16 |
-
|
| 17 |
-
if isinstance(obj, (datetime.datetime,)):
|
| 18 |
-
return obj.isoformat()
|
| 19 |
-
raise TypeError("Type %s not serializable" % type(obj))
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# Modified version of lm_dataformat Archive for single file.
|
| 23 |
-
class Archive:
|
| 24 |
-
def __init__(self, file_path: str, compression_level: int = 3) -> None:
|
| 25 |
-
self.file_path = file_path
|
| 26 |
-
dir_name = os.path.dirname(file_path)
|
| 27 |
-
if dir_name:
|
| 28 |
-
os.makedirs(dir_name, exist_ok=True)
|
| 29 |
-
self.fh = open(self.file_path, "wb")
|
| 30 |
-
self.cctx = zstandard.ZstdCompressor(level=compression_level)
|
| 31 |
-
self.compressor = self.cctx.stream_writer(self.fh)
|
| 32 |
-
|
| 33 |
-
def add_data(self, data, meta=None) -> None:
|
| 34 |
-
if meta is None:
|
| 35 |
-
meta = {}
|
| 36 |
-
self.compressor.write(
|
| 37 |
-
json.dumps({"text": data, "meta": meta}, default=json_serial).encode(
|
| 38 |
-
"UTF-8"
|
| 39 |
-
)
|
| 40 |
-
+ b"\n"
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
def commit(self) -> None:
|
| 44 |
-
self.compressor.flush(zstandard.FLUSH_FRAME)
|
| 45 |
-
self.fh.flush()
|
| 46 |
-
self.fh.close()
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
|
| 50 |
-
class Reader:
|
| 51 |
-
def __init__(self) -> None:
|
| 52 |
-
pass
|
| 53 |
-
|
| 54 |
-
def read(
|
| 55 |
-
self,
|
| 56 |
-
file,
|
| 57 |
-
get_meta: bool = False,
|
| 58 |
-
autojoin_paragraphs: bool = True,
|
| 59 |
-
para_joiner: str = "\n\n",
|
| 60 |
-
):
|
| 61 |
-
with open(file, "rb") as fh:
|
| 62 |
-
self.fh = fh
|
| 63 |
-
cctx = zstandard.ZstdDecompressor()
|
| 64 |
-
reader = io.BufferedReader(cctx.stream_reader(fh))
|
| 65 |
-
rdr = jsonlines.Reader(reader)
|
| 66 |
-
for ob in rdr:
|
| 67 |
-
# naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
|
| 68 |
-
if isinstance(ob, str):
|
| 69 |
-
assert not get_meta
|
| 70 |
-
yield ob
|
| 71 |
-
continue
|
| 72 |
-
|
| 73 |
-
text = ob["text"]
|
| 74 |
-
|
| 75 |
-
if autojoin_paragraphs and isinstance(text, list):
|
| 76 |
-
text = para_joiner.join(text)
|
| 77 |
-
|
| 78 |
-
if get_meta:
|
| 79 |
-
yield text, (ob["meta"] if "meta" in ob else {})
|
| 80 |
-
else:
|
| 81 |
-
yield text
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
class TextArchive:
|
| 85 |
-
def __init__(self, file_path, mode: str = "rb+") -> None:
|
| 86 |
-
self.file_path = file_path
|
| 87 |
-
dir_name = os.path.dirname(file_path)
|
| 88 |
-
if dir_name:
|
| 89 |
-
os.makedirs(dir_name, exist_ok=True)
|
| 90 |
-
|
| 91 |
-
if not os.path.exists(file_path):
|
| 92 |
-
Path(file_path).touch()
|
| 93 |
-
|
| 94 |
-
self.fh = open(self.file_path, mode)
|
| 95 |
-
|
| 96 |
-
def add_data(self, data) -> None:
|
| 97 |
-
self.fh.write(data.encode("UTF-8") + b"\n")
|
| 98 |
-
|
| 99 |
-
def commit(self) -> None:
|
| 100 |
-
self.fh.flush()
|
| 101 |
-
self.fh.close()
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class TextReader:
|
| 105 |
-
def __init__(self, file_path) -> None:
|
| 106 |
-
self.file_path = file_path
|
| 107 |
-
|
| 108 |
-
# Optimized mmap read with infrequent tqdm updates to maintain speed
|
| 109 |
-
# Tested up to 250MB/s.
|
| 110 |
-
def read_tqdm(self, update_frequency: int = 10000):
|
| 111 |
-
current_file_position = 0
|
| 112 |
-
line_counter = 0
|
| 113 |
-
with (
|
| 114 |
-
open(self.file_path, "r", encoding="utf-8") as fh,
|
| 115 |
-
tqdm.tqdm(
|
| 116 |
-
total=os.path.getsize(self.file_path),
|
| 117 |
-
dynamic_ncols=True,
|
| 118 |
-
unit="byte",
|
| 119 |
-
unit_scale=1,
|
| 120 |
-
) as progress,
|
| 121 |
-
):
|
| 122 |
-
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
|
| 123 |
-
for line in iter(mmap_obj.readline, b""):
|
| 124 |
-
line = line.decode("utf-8")
|
| 125 |
-
line_counter += 1
|
| 126 |
-
if line_counter == update_frequency:
|
| 127 |
-
new_file_pos = mmap_obj.tell()
|
| 128 |
-
bytes_read = new_file_pos - current_file_position
|
| 129 |
-
current_file_position = new_file_pos
|
| 130 |
-
progress.update(bytes_read)
|
| 131 |
-
line_counter = 0
|
| 132 |
-
yield line[:-1]
|
| 133 |
-
|
| 134 |
-
def read_and_tell(self):
|
| 135 |
-
current_file_position = 0
|
| 136 |
-
with open(self.file_path, "r", encoding="utf8") as fh:
|
| 137 |
-
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
|
| 138 |
-
for line in iter(mmap_obj.readline, b""):
|
| 139 |
-
line = line.decode("utf-8")
|
| 140 |
-
new_file_pos = mmap_obj.tell()
|
| 141 |
-
raw_bytes_read = new_file_pos - current_file_position
|
| 142 |
-
current_file_position = new_file_pos
|
| 143 |
-
yield line[:-1], raw_bytes_read
|
| 144 |
-
|
| 145 |
-
def read(self):
|
| 146 |
-
with open(self.file_path, "r", encoding="utf8") as fh:
|
| 147 |
-
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
|
| 148 |
-
for line in iter(mmap_obj.readline, b""):
|
| 149 |
-
line = line.decode("utf-8")
|
| 150 |
-
yield line[:-1]
|
| 151 |
-
|
| 152 |
-
def read_slow(self):
|
| 153 |
-
with open(self.file_path, "r", encoding="utf8") as fh:
|
| 154 |
-
while True:
|
| 155 |
-
line = fh.readline()
|
| 156 |
-
if line == -1 or line == "":
|
| 157 |
-
break
|
| 158 |
-
else:
|
| 159 |
-
yield line[:-1]
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
# Optimized for speed. Decompresses the archive in shell before
|
| 163 |
-
# using the mmap'd TextReader.
|
| 164 |
-
class ZStdTextReader:
|
| 165 |
-
def __init__(self, file) -> None:
|
| 166 |
-
self.file = file
|
| 167 |
-
|
| 168 |
-
def read_tqdm(self):
|
| 169 |
-
decompressed_file = self.file[:-4]
|
| 170 |
-
print("Decompressing file, please wait...")
|
| 171 |
-
os.system(f"zstd -d {self.file}") # linux decompress is faster
|
| 172 |
-
reader = TextReader(decompressed_file)
|
| 173 |
-
yield from reader.read_tqdm()
|
| 174 |
-
os.remove(decompressed_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/decontamination/decontaminate.py
DELETED
|
@@ -1,166 +0,0 @@
|
|
| 1 |
-
import collections
|
| 2 |
-
import glob
|
| 3 |
-
import json
|
| 4 |
-
import os
|
| 5 |
-
import pickle
|
| 6 |
-
import random
|
| 7 |
-
import time
|
| 8 |
-
|
| 9 |
-
from .archiver import ZStdTextReader
|
| 10 |
-
from .janitor import Janitor, word_ngrams
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
# Was used for testing the evaluator decoupled from the full logic below
|
| 14 |
-
def get_train_overlap_stub(docs: dict, ngrams_path: str, ngrams_n_size: str):
|
| 15 |
-
simulated_overlap = 0.1
|
| 16 |
-
contaminated = int(len(docs) * simulated_overlap)
|
| 17 |
-
return random.sample(range(len(docs)), contaminated)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
# Returns a dictionary containing all overlapping documents in each
|
| 21 |
-
# task. In the standard use case, an overlap occurs when any of the 13-grams
|
| 22 |
-
# found in the task document exist in the training set documents.
|
| 23 |
-
#
|
| 24 |
-
# To generate 13-grams for the pile see scripts/clean_training_data. The final output of these
|
| 25 |
-
# scripts are an info.json file containing the n_gram_size (13) and a bunch of "ngrams_{x}.bkt.txt.sorted.zst"
|
| 26 |
-
# files. These should exist in the "ngrams_path" provided to this function.
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
# Algorithm:
|
| 30 |
-
# 1. Build lookups for each dataset {ngram: list(document_ids)}
|
| 31 |
-
# 2. Merge into an overall lookup {ngram: [(task_name, task_set, doc_ids),]}
|
| 32 |
-
# 3. Full scan the 13-grams from the training set against the merged lookup,
|
| 33 |
-
# saving matches in the "duplicates" dictionary {(task_name, task_set): set(doc_ids)}
|
| 34 |
-
# 4. Strip the task_set from the dictionary keys and return
|
| 35 |
-
#
|
| 36 |
-
# We cache the task+set lookups as well as the overlaps.
|
| 37 |
-
def get_train_overlap(docs_by_task_set: dict, ngrams_path: str, limit: int) -> dict:
|
| 38 |
-
# return get_train_overlap_stub(docs, ngrams_path, ngrams_n_size)
|
| 39 |
-
|
| 40 |
-
info_dict_path = os.path.join(ngrams_path, "info.json")
|
| 41 |
-
info_dict = json.load(open(info_dict_path, "r", encoding="utf-8"))
|
| 42 |
-
ngrams_n_size = info_dict["ngram_size"]
|
| 43 |
-
|
| 44 |
-
janitor = Janitor()
|
| 45 |
-
|
| 46 |
-
# Build lookup for each dataset first in case we use different task combinations later
|
| 47 |
-
print("Building Lookups...")
|
| 48 |
-
start = time.perf_counter()
|
| 49 |
-
|
| 50 |
-
def get_overlaps_dump_path(task_name, task_set, ngrams_n_size, limit) -> str:
|
| 51 |
-
return f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.overlaps"
|
| 52 |
-
|
| 53 |
-
lookups = {}
|
| 54 |
-
duplicates = {} # (task_name, task_set): set(doc_ids)}
|
| 55 |
-
sets_to_decontaminate = len(docs_by_task_set.keys())
|
| 56 |
-
|
| 57 |
-
for (task_name, task_set), docs in docs_by_task_set.items():
|
| 58 |
-
if not os.path.exists(f"data/{task_name}"):
|
| 59 |
-
os.mkdir(f"data/{task_name}")
|
| 60 |
-
|
| 61 |
-
# Check if we've decontaminated this combination before
|
| 62 |
-
overlaps_dump_path = get_overlaps_dump_path(
|
| 63 |
-
task_name, task_set, ngrams_n_size, limit
|
| 64 |
-
)
|
| 65 |
-
if os.path.exists(overlaps_dump_path):
|
| 66 |
-
duplicates[(task_name, task_set)] = pickle.load(
|
| 67 |
-
open(overlaps_dump_path, "rb")
|
| 68 |
-
)
|
| 69 |
-
sets_to_decontaminate -= 1
|
| 70 |
-
continue
|
| 71 |
-
else:
|
| 72 |
-
duplicates[(task_name, task_set)] = set()
|
| 73 |
-
|
| 74 |
-
# Build/load the task lookup {ngram: set(documents)}.
|
| 75 |
-
task_set_lookup_path = (
|
| 76 |
-
f"data/{task_name}/{task_set}_{ngrams_n_size}grams_limit{limit}.lookup"
|
| 77 |
-
)
|
| 78 |
-
if os.path.exists(task_set_lookup_path):
|
| 79 |
-
print(f"{task_set_lookup_path} available, loading...")
|
| 80 |
-
lookups[(task_name, task_set)] = pickle.load(
|
| 81 |
-
open(task_set_lookup_path, "rb")
|
| 82 |
-
)
|
| 83 |
-
else:
|
| 84 |
-
print(f"{task_set_lookup_path} not available, building...")
|
| 85 |
-
lookup = collections.defaultdict(set)
|
| 86 |
-
|
| 87 |
-
for doc_id, document in enumerate(docs):
|
| 88 |
-
ngrams = word_ngrams(janitor.normalize_string(document), ngrams_n_size)
|
| 89 |
-
for ngram in ngrams:
|
| 90 |
-
lookup[ngram].add(doc_id)
|
| 91 |
-
|
| 92 |
-
pickle.dump(lookup, open(task_set_lookup_path, "wb"))
|
| 93 |
-
lookups[(task_name, task_set)] = lookup
|
| 94 |
-
|
| 95 |
-
elapsed = time.perf_counter() - start
|
| 96 |
-
print(f"Building lookups took {elapsed:0.5f} seconds.")
|
| 97 |
-
|
| 98 |
-
matched_ngrams = []
|
| 99 |
-
|
| 100 |
-
if sets_to_decontaminate > 0:
|
| 101 |
-
print("Merging lookups...")
|
| 102 |
-
start = time.perf_counter()
|
| 103 |
-
merged_lookup = collections.defaultdict(list)
|
| 104 |
-
for (task_name, task_set), lookup in lookups.items():
|
| 105 |
-
for ngram, doc_ids in lookup.items():
|
| 106 |
-
merged_lookup[ngram].append((task_name, task_set, doc_ids))
|
| 107 |
-
|
| 108 |
-
elapsed = time.perf_counter() - start
|
| 109 |
-
print(f"Merging lookups took {elapsed:0.5f} seconds.")
|
| 110 |
-
|
| 111 |
-
print(f"{ngrams_n_size} grams files found in {ngrams_path}:")
|
| 112 |
-
files = glob.glob(os.path.join(ngrams_path, "*.sorted.zst"))
|
| 113 |
-
print(files)
|
| 114 |
-
|
| 115 |
-
for file in files:
|
| 116 |
-
start = time.perf_counter()
|
| 117 |
-
print(f"Scanning {file}")
|
| 118 |
-
reader = ZStdTextReader(file)
|
| 119 |
-
total_ngrams = 0
|
| 120 |
-
unique_ngrams = 0
|
| 121 |
-
matching_unique = 0
|
| 122 |
-
non_matching_unique = 0
|
| 123 |
-
|
| 124 |
-
current_ngram = ""
|
| 125 |
-
for line in reader.read_tqdm(): # Scan training set ngrams file
|
| 126 |
-
total_ngrams += 1
|
| 127 |
-
[ngram, document_id] = line.rsplit(" ", 1)
|
| 128 |
-
if (
|
| 129 |
-
ngram != current_ngram
|
| 130 |
-
): # Only need to match the ngram once in training set
|
| 131 |
-
unique_ngrams += 1
|
| 132 |
-
current_ngram = ngram
|
| 133 |
-
if ngram in merged_lookup:
|
| 134 |
-
matched_ngrams.append(ngram) # For logging
|
| 135 |
-
matching_unique += 1
|
| 136 |
-
for task_name, task_set, doc_ids in merged_lookup[ngram]:
|
| 137 |
-
task_doc_set = duplicates[(task_name, task_set)]
|
| 138 |
-
for doc_id in doc_ids: # Record contamination across all relevant task/set combos
|
| 139 |
-
task_doc_set.add(doc_id)
|
| 140 |
-
del merged_lookup[ngram] # No point matching again
|
| 141 |
-
else:
|
| 142 |
-
non_matching_unique += 1
|
| 143 |
-
|
| 144 |
-
print(f"Total Ngrams: {total_ngrams}")
|
| 145 |
-
print(f"Unique Ngrams: {unique_ngrams}")
|
| 146 |
-
print(f"Unique Matching: {matching_unique}")
|
| 147 |
-
print(f"Unique Non Matching: {non_matching_unique}")
|
| 148 |
-
print("Matched ngrams:")
|
| 149 |
-
for ngram in matched_ngrams:
|
| 150 |
-
print(ngram)
|
| 151 |
-
|
| 152 |
-
elapsed = time.perf_counter() - start
|
| 153 |
-
print(f"Read took {elapsed:0.5f} seconds.")
|
| 154 |
-
print(f"Speed: {(os.path.getsize(file) / 1000000.0) / elapsed}MB/second")
|
| 155 |
-
|
| 156 |
-
print(duplicates)
|
| 157 |
-
|
| 158 |
-
# Dump overlaps separately
|
| 159 |
-
for (task_name, task_set), doc_ids in duplicates.items():
|
| 160 |
-
overlaps_dump_path = get_overlaps_dump_path(
|
| 161 |
-
task_name, task_set, ngrams_n_size, limit
|
| 162 |
-
)
|
| 163 |
-
pickle.dump(doc_ids, open(overlaps_dump_path, "wb"))
|
| 164 |
-
|
| 165 |
-
# Strip task set and return
|
| 166 |
-
return {task_name: doc_ids for (task_name, task_set), doc_ids in duplicates.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/decontamination/janitor.py
DELETED
|
@@ -1,328 +0,0 @@
|
|
| 1 |
-
import pickle
|
| 2 |
-
import re
|
| 3 |
-
import string
|
| 4 |
-
import traceback
|
| 5 |
-
from typing import Iterator, List, Sequence, Tuple, TypeVar
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
# This is a cpp module. Compile janitor_util.cpp with:
|
| 9 |
-
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
|
| 10 |
-
try:
|
| 11 |
-
import janitor_util
|
| 12 |
-
|
| 13 |
-
JANITOR_CPP = True
|
| 14 |
-
except Exception:
|
| 15 |
-
print("WARNING: C++ module could not be loaded. Janitor running in python mode")
|
| 16 |
-
traceback.print_exc()
|
| 17 |
-
JANITOR_CPP = False
|
| 18 |
-
|
| 19 |
-
T = TypeVar("T")
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
# Implementation from nltk source
|
| 23 |
-
# https://www.nltk.org/_modules/nltk/util.html
|
| 24 |
-
def form_ngrams(sequence: Iterator[T], n: int) -> Iterator[Tuple[T, ...]]:
|
| 25 |
-
history = []
|
| 26 |
-
while n > 1:
|
| 27 |
-
# PEP 479, prevent RuntimeError from being raised when StopIteration bubbles out of generator
|
| 28 |
-
try:
|
| 29 |
-
next_item = next(sequence)
|
| 30 |
-
except StopIteration:
|
| 31 |
-
# no more data, terminate the generator
|
| 32 |
-
return
|
| 33 |
-
history.append(next_item)
|
| 34 |
-
n -= 1
|
| 35 |
-
for item in sequence:
|
| 36 |
-
history.append(item)
|
| 37 |
-
yield tuple(history)
|
| 38 |
-
del history[0]
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def word_ngrams(s: str, n: int) -> Iterator[str]:
|
| 42 |
-
"""Splits a string into ngram words"""
|
| 43 |
-
tokens = s.split() # not a generator :(
|
| 44 |
-
ngram_seqs = form_ngrams(iter(tokens), n)
|
| 45 |
-
return (" ".join(ngram) for ngram in ngram_seqs)
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
# Does character sequences only - combined faster function to play around with later
|
| 49 |
-
# def word_ngrams_indices_combined(sequence, n):
|
| 50 |
-
# current_word = ""
|
| 51 |
-
# history = []
|
| 52 |
-
# gap = False;
|
| 53 |
-
# start = 0
|
| 54 |
-
# end = 0
|
| 55 |
-
# for character in sequence:
|
| 56 |
-
# if character == " ":
|
| 57 |
-
# if not gap:
|
| 58 |
-
# gap = True
|
| 59 |
-
# history.append(current_word)
|
| 60 |
-
# end += len(current_word) - 1
|
| 61 |
-
# current_word = ""
|
| 62 |
-
# if len(history) == n:
|
| 63 |
-
# yield (tuple(history), start, end)
|
| 64 |
-
# del history[0]
|
| 65 |
-
# start = end + 1
|
| 66 |
-
# end = start
|
| 67 |
-
# else:
|
| 68 |
-
# gap = False
|
| 69 |
-
# current_word += character
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
# https://stackoverflow.com/questions/13734451/string-split-with-indices-in-python
|
| 73 |
-
def split_indices(s: str) -> Iterator[Tuple[str, Tuple[int, int]]]:
|
| 74 |
-
"""Splits a string on whitespaces and records the indices of each in the original string.
|
| 75 |
-
@:return generator((word, (start_idx, end_idx)), ...)
|
| 76 |
-
"""
|
| 77 |
-
return ((m.group(0), (m.start(), m.end() - 1)) for m in re.finditer(r"\S+", s))
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def word_ngrams_indices(s: str, n: int) -> Iterator[Tuple[str, Tuple[int, int]]]:
|
| 81 |
-
"""Splits a string into pairs of (ngram words, their start/end indices)"""
|
| 82 |
-
tokens_with_indices = split_indices(s)
|
| 83 |
-
|
| 84 |
-
# Generator of ngrams of (word, idx_pairs)
|
| 85 |
-
# (
|
| 86 |
-
# [(word, (start,end)), (word, (start, end))...],
|
| 87 |
-
# [(word, (start, end)), ...],
|
| 88 |
-
# ...
|
| 89 |
-
# )
|
| 90 |
-
ngram_seqs_with_indices = form_ngrams(tokens_with_indices, n)
|
| 91 |
-
|
| 92 |
-
# Generator of pairs of word and index ngrams
|
| 93 |
-
# (
|
| 94 |
-
# ([word, word, ...], [(start,end), (start,end), ...]),
|
| 95 |
-
# ...
|
| 96 |
-
# )
|
| 97 |
-
ngram_indices_pairs = (
|
| 98 |
-
zip(*ngram_with_indices) for ngram_with_indices in ngram_seqs_with_indices
|
| 99 |
-
)
|
| 100 |
-
|
| 101 |
-
# Generator of ( (word_ngram, (start, end)), (word_ngram, start, end)), ...)
|
| 102 |
-
return (
|
| 103 |
-
(" ".join(ngram_seq), (indices[0][0], indices[-1][1]))
|
| 104 |
-
for ngram_seq, indices in ngram_indices_pairs
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
class Janitor:
|
| 109 |
-
# FIXME delete_chars: Should anything else go here? Special chars?
|
| 110 |
-
def __init__(
|
| 111 |
-
self,
|
| 112 |
-
ngram_n: int = 13,
|
| 113 |
-
window_to_remove: int = 200,
|
| 114 |
-
too_dirty_cutoff: int = 10,
|
| 115 |
-
minimum_slice_length: int = 200,
|
| 116 |
-
delete_chars: str = string.punctuation,
|
| 117 |
-
) -> None:
|
| 118 |
-
self.ngram_n = ngram_n
|
| 119 |
-
self.window_to_remove = window_to_remove
|
| 120 |
-
self.too_dirty_cutoff = too_dirty_cutoff
|
| 121 |
-
self.minimum_slice_length = minimum_slice_length
|
| 122 |
-
self.delete_chars = delete_chars
|
| 123 |
-
|
| 124 |
-
self.dirt_ngrams = set()
|
| 125 |
-
|
| 126 |
-
# If in python, we'll translate uppercase to lowercase and delete naughty characters.
|
| 127 |
-
# This is fast by python standards
|
| 128 |
-
# https://stackoverflow.com/questions/638893/what-is-the-most-efficient-way-in-python-to-convert-a-string-to-all-lowercase-st
|
| 129 |
-
self.translation_table = str.maketrans(
|
| 130 |
-
string.ascii_lowercase + string.ascii_uppercase, # These characters
|
| 131 |
-
string.ascii_lowercase * 2, # Become these characters
|
| 132 |
-
self.delete_chars, # These are deleted
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
##############
|
| 136 |
-
# I/O for saving contamination ngrams
|
| 137 |
-
##############
|
| 138 |
-
|
| 139 |
-
def save_contamination_ngrams(self, filename: str) -> None:
|
| 140 |
-
with open(filename, "wb") as fp:
|
| 141 |
-
pickle.dump(filename, fp)
|
| 142 |
-
|
| 143 |
-
def load_contamination_ngrams(self, filename: str) -> None:
|
| 144 |
-
with open(filename, "rb") as fp:
|
| 145 |
-
self.dirt_ngrams = pickle.load(fp)
|
| 146 |
-
|
| 147 |
-
##############
|
| 148 |
-
# Call these :)
|
| 149 |
-
##############
|
| 150 |
-
|
| 151 |
-
def register_contaminant(self, dirt_string: str) -> None:
|
| 152 |
-
"""Register a string as contamination to be removed, e.g. a test set
|
| 153 |
-
This breaks the dirt_string into ngrams to store for future cleaning"""
|
| 154 |
-
if JANITOR_CPP:
|
| 155 |
-
return self.register_contaminant_cpp(dirt_string)
|
| 156 |
-
else:
|
| 157 |
-
print("WARNING: Janitor running in python mode")
|
| 158 |
-
return self.register_contaminant_python(dirt_string)
|
| 159 |
-
|
| 160 |
-
def clean(self, dirty_string: str) -> List[str]:
|
| 161 |
-
"""Clean a string (e.g. a training set) by removing all ngrams previously
|
| 162 |
-
registered as contaminants. Returns a list of clean chunks, or empty if
|
| 163 |
-
the string was too dirty"""
|
| 164 |
-
if JANITOR_CPP:
|
| 165 |
-
return self.clean_cpp(dirty_string)
|
| 166 |
-
else:
|
| 167 |
-
print("WARNING: Janitor running in python mode")
|
| 168 |
-
return self.clean_python(dirty_string)
|
| 169 |
-
|
| 170 |
-
def _split_chunks(
|
| 171 |
-
self, dirty_string: str, dirty_parts: Sequence[Tuple]
|
| 172 |
-
) -> List[str]:
|
| 173 |
-
clean_chunks = []
|
| 174 |
-
splice_idx = 0
|
| 175 |
-
end = -1
|
| 176 |
-
for i, (ngram, start, end) in enumerate(dirty_parts):
|
| 177 |
-
if i >= self.too_dirty_cutoff:
|
| 178 |
-
return []
|
| 179 |
-
start = max(0, start - self.window_to_remove)
|
| 180 |
-
end = min(len(dirty_string), end + self.window_to_remove)
|
| 181 |
-
|
| 182 |
-
if start - splice_idx > self.minimum_slice_length:
|
| 183 |
-
clean_chunks.append(dirty_string[splice_idx:start])
|
| 184 |
-
splice_idx = end
|
| 185 |
-
|
| 186 |
-
if end < len(dirty_string) - self.minimum_slice_length:
|
| 187 |
-
clean_chunks.append(dirty_string[end + 1 :])
|
| 188 |
-
|
| 189 |
-
return clean_chunks
|
| 190 |
-
|
| 191 |
-
##############
|
| 192 |
-
# Fast C++
|
| 193 |
-
##############
|
| 194 |
-
|
| 195 |
-
def register_contaminant_cpp(self, dirt_string) -> None:
|
| 196 |
-
self.dirt_ngrams.update(
|
| 197 |
-
janitor_util.clean_ngram(dirt_string, self.delete_chars, self.ngram_n)
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
def clean_cpp(self, dirty_string: str) -> List[str]:
|
| 201 |
-
contamination_indices = janitor_util.clean_ngram_with_indices(
|
| 202 |
-
dirty_string, self.delete_chars, self.ngram_n
|
| 203 |
-
)
|
| 204 |
-
return self._split_chunks(dirty_string, contamination_indices)
|
| 205 |
-
|
| 206 |
-
##############
|
| 207 |
-
# Slow python
|
| 208 |
-
##############
|
| 209 |
-
|
| 210 |
-
def normalize_string(self, s: str) -> str:
|
| 211 |
-
return s.translate(self.translation_table)
|
| 212 |
-
|
| 213 |
-
def register_contaminant_python(self, dirt_string: str) -> None:
|
| 214 |
-
self.dirt_ngrams.update(
|
| 215 |
-
word_ngrams(self.normalize_string(dirt_string), self.ngram_n)
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
def clean_python(self, dirty_string: str) -> List[str]:
|
| 219 |
-
contamination_indices = (
|
| 220 |
-
(None, *idx_pair)
|
| 221 |
-
for dirty_ngram, idx_pair in word_ngrams_indices(dirty_string, self.ngram_n)
|
| 222 |
-
if self.normalize_string(dirty_ngram) in self.dirt_ngrams
|
| 223 |
-
)
|
| 224 |
-
return self._split_chunks(dirty_string, contamination_indices)
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
##################################################################
|
| 228 |
-
# Tests
|
| 229 |
-
#################################################################
|
| 230 |
-
|
| 231 |
-
# def print_cpp():
|
| 232 |
-
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 233 |
-
|
| 234 |
-
# for i in range(1, 10, 2):
|
| 235 |
-
# pprint(janitor_util.clean_ngram(source, string.punctuation, i))
|
| 236 |
-
# for ngram, start, end in \
|
| 237 |
-
# janitor_util.clean_ngram_with_indices(source, string.punctuation, i):
|
| 238 |
-
# print(ngram, "\t", start, end, source[start:end].replace("\n", "\\n"))
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
# def test_cpp():
|
| 242 |
-
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 243 |
-
# contaminant = "dirty boy. Clean he he"
|
| 244 |
-
|
| 245 |
-
# jan_python = Janitor()
|
| 246 |
-
# jan_cpp = Janitor()
|
| 247 |
-
|
| 248 |
-
# jan_python.register_contaminant_python(contaminant)
|
| 249 |
-
# jan_cpp.register_contaminant(contaminant)
|
| 250 |
-
|
| 251 |
-
# assert jan_python.dirt_ngrams == jan_cpp.dirt_ngrams, (jan_python.dirt_ngrams, jan_cpp.dirt_ngrams)
|
| 252 |
-
|
| 253 |
-
# assert jan_python.clean_python(source) == jan_cpp.clean(source), \
|
| 254 |
-
# (jan_python.clean_python(source), jan_cpp.clean(source))
|
| 255 |
-
|
| 256 |
-
# print("Passed test, python==cpp")
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
# def benchmark():
|
| 260 |
-
# # Download and put in data folder: enwik8 (100 MB) from https://cs.fit.edu/~mmahoney/compression/textdata.html
|
| 261 |
-
# setup = \
|
| 262 |
-
# """
|
| 263 |
-
# with open("data/enwik8", "r") as f:
|
| 264 |
-
# data = f.read()
|
| 265 |
-
# jan = Janitor(too_dirty_cutoff=1000)
|
| 266 |
-
# jan.register_contaminant('''
|
| 267 |
-
# theories is that there is a connection between "geekdom" and autism.
|
| 268 |
-
# This is hinted, for instance, by a ''Wired Magazine'' article in 2001 entitled "
|
| 269 |
-
# The [[Geek]] Syndrome", which is a point argued by many in the autism rights
|
| 270 |
-
# movement{{ref|Wired}}. This article, many professionals assert, is just one example of
|
| 271 |
-
# the media's application of mental disease labels to what is actually variant normal behavior
|
| 272 |
-
# &mdash;they argue that shyness, lack of athletic ability or social skills, and intellectual
|
| 273 |
-
# interests, even when they seem unusual to others, are not in themselves signs of autism or
|
| 274 |
-
# Asperger's syndrome. Others assert that it is actually the medical profession which is applying
|
| 275 |
-
# mental disease labels to children who in the past would have simply been accepted as a little
|
| 276 |
-
# different or even labeled 'gifted'. See [[clinomorphism]] for further discussion of this issue.
|
| 277 |
-
# Due to the recent publicity surrounding autism and autis
|
| 278 |
-
# ultan Al Nahyan]] granted [[Petroleum]] concessions, and oil was first found in 1958. At first,
|
| 279 |
-
# oil money had a marginal impact. A few lowrise concete buildings were erected, and the first
|
| 280 |
-
# paved road was completed in 1961, but Sheikh Shakbut, uncertain whether the new oil royalties
|
| 281 |
-
# would last, took a cautious approach, preferring to save the revenue rather than investing it in
|
| 282 |
-
# development. His brother, [[Zayed bin Sultan Al Nahayan]], saw that oil wealth had the potential
|
| 283 |
-
# to transform Abu Dhabi. The ruling Al Nahayan family decided that Sheikh Zayed should replace his
|
| 284 |
-
# brother as Ruler and carry out his vision of developing the country. On [[August 6]], [[1966]],
|
| 285 |
-
# with the assistance of the British, Sheikh Zayed became the new ruler. See generally, Al-Fahim, M,
|
| 286 |
-
# ''From Rags to Riches: A Story of Abu Dhabi'', Chapter Six (London Centre of Arab Studies, 1995),
|
| 287 |
-
# ISBN 1 900404 00 1. With the announcement by Britain in 1968 that it would withdraw from the
|
| 288 |
-
# Gulf area by 1971, Sheikh Zayed became the main driving force behind the formation of the
|
| 289 |
-
# [[United Arab Emirates]]. After the Emirates gained independence in 1971,
|
| 290 |
-
# ''')
|
| 291 |
-
# """
|
| 292 |
-
|
| 293 |
-
# n = 1
|
| 294 |
-
# print(f"Timing {n} run on 100 MB")
|
| 295 |
-
# print("Register contaminant")
|
| 296 |
-
# # print("\tPython", timeit.timeit("jan.register_contaminant_python(data)", setup=setup, globals=globals(), number=n))
|
| 297 |
-
# print("\tCpp", timeit.timeit("jan.register_contaminant(data)", setup=setup, globals=globals(), number=n))
|
| 298 |
-
|
| 299 |
-
# print("Clean")
|
| 300 |
-
# # print("\tPython", timeit.timeit("jan.clean_python(data)", setup=setup, globals=globals(), number=n))
|
| 301 |
-
# print("\tCpp", timeit.timeit("jan.clean(data)", setup=setup, globals=globals(), number=n))
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
# def test_janitor_general():
|
| 305 |
-
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
|
| 306 |
-
# contaminant = "dirty boy. Clean he he"
|
| 307 |
-
|
| 308 |
-
# jan = Janitor(ngram_n=3)
|
| 309 |
-
# jan.register_contaminant(contaminant)
|
| 310 |
-
# cleaned = " ".join(jan.clean(source))
|
| 311 |
-
# for contam in jan.dirt_ngrams:
|
| 312 |
-
# assert contam not in cleaned, contam
|
| 313 |
-
|
| 314 |
-
# filename = "data/saved_contam"
|
| 315 |
-
# jan.save_contamination_ngrams(filename)
|
| 316 |
-
|
| 317 |
-
# jan = Janitor(ngram_n=3)
|
| 318 |
-
# jan.load_contamination_ngrams(filename)
|
| 319 |
-
# cleaned = " ".join(jan.clean(source))
|
| 320 |
-
# for contam in jan.dirt_ngrams:
|
| 321 |
-
# assert contam not in cleaned, contam
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
# if __name__ == "__main__":
|
| 325 |
-
# test()
|
| 326 |
-
# # print_cpp()
|
| 327 |
-
# # test_cpp()
|
| 328 |
-
# # benchmark()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/evaluator.py
DELETED
|
@@ -1,761 +0,0 @@
|
|
| 1 |
-
import itertools
|
| 2 |
-
import json
|
| 3 |
-
import logging
|
| 4 |
-
import random
|
| 5 |
-
import time
|
| 6 |
-
from collections import defaultdict
|
| 7 |
-
from typing import TYPE_CHECKING, List, Optional, Union
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import torch
|
| 11 |
-
|
| 12 |
-
import lm_eval.api.metrics
|
| 13 |
-
import lm_eval.api.registry
|
| 14 |
-
import lm_eval.api.task
|
| 15 |
-
import lm_eval.models
|
| 16 |
-
from lm_eval.caching.cache import delete_cache
|
| 17 |
-
from lm_eval.evaluator_utils import (
|
| 18 |
-
consolidate_group_results,
|
| 19 |
-
consolidate_results,
|
| 20 |
-
get_sample_size,
|
| 21 |
-
get_subtask_list,
|
| 22 |
-
get_task_list,
|
| 23 |
-
prepare_print_tasks,
|
| 24 |
-
print_writeout,
|
| 25 |
-
run_task_tests,
|
| 26 |
-
)
|
| 27 |
-
from lm_eval.loggers import EvaluationTracker
|
| 28 |
-
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
|
| 29 |
-
from lm_eval.tasks import TaskManager, get_task_dict
|
| 30 |
-
from lm_eval.utils import (
|
| 31 |
-
handle_non_serializable,
|
| 32 |
-
hash_string,
|
| 33 |
-
positional_deprecated,
|
| 34 |
-
setup_logging,
|
| 35 |
-
simple_parse_args_string,
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
if TYPE_CHECKING:
|
| 40 |
-
from lm_eval.api.model import LM
|
| 41 |
-
from lm_eval.api.task import Task
|
| 42 |
-
|
| 43 |
-
eval_logger = logging.getLogger(__name__)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
@positional_deprecated
|
| 47 |
-
def simple_evaluate(
|
| 48 |
-
model,
|
| 49 |
-
model_args: Optional[Union[str, dict]] = None,
|
| 50 |
-
tasks: Optional[List[Union[str, dict, object]]] = None,
|
| 51 |
-
num_fewshot: Optional[int] = None,
|
| 52 |
-
batch_size: Optional[Union[int, str]] = None,
|
| 53 |
-
max_batch_size: Optional[int] = None,
|
| 54 |
-
device: Optional[str] = None,
|
| 55 |
-
use_cache: Optional[str] = None,
|
| 56 |
-
cache_requests: bool = False,
|
| 57 |
-
rewrite_requests_cache: bool = False,
|
| 58 |
-
delete_requests_cache: bool = False,
|
| 59 |
-
limit: Optional[Union[int, float]] = None,
|
| 60 |
-
samples: Optional[dict] = None,
|
| 61 |
-
bootstrap_iters: int = 100000,
|
| 62 |
-
check_integrity: bool = False,
|
| 63 |
-
write_out: bool = False,
|
| 64 |
-
log_samples: bool = True,
|
| 65 |
-
evaluation_tracker: Optional[EvaluationTracker] = None,
|
| 66 |
-
system_instruction: Optional[str] = None,
|
| 67 |
-
apply_chat_template: Union[bool, str] = False,
|
| 68 |
-
fewshot_as_multiturn: bool = False,
|
| 69 |
-
gen_kwargs: Union[str, dict, None] = None,
|
| 70 |
-
task_manager: Optional[TaskManager] = None,
|
| 71 |
-
verbosity=None,
|
| 72 |
-
predict_only: bool = False,
|
| 73 |
-
random_seed: int = 0,
|
| 74 |
-
numpy_random_seed: int = 1234,
|
| 75 |
-
torch_random_seed: int = 1234,
|
| 76 |
-
fewshot_random_seed: int = 1234,
|
| 77 |
-
confirm_run_unsafe_code: bool = False,
|
| 78 |
-
metadata: Optional[dict] = None,
|
| 79 |
-
):
|
| 80 |
-
"""Instantiate and evaluate a model on a list of tasks.
|
| 81 |
-
|
| 82 |
-
:param model: Union[str, LM]
|
| 83 |
-
Name of model or LM object, see lm_eval.models.get_model
|
| 84 |
-
:param model_args: Optional[str, dict]
|
| 85 |
-
String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
|
| 86 |
-
Ignored if `model` argument is a LM object.
|
| 87 |
-
:param tasks: list[Union[str, dict, Task]]
|
| 88 |
-
List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
|
| 89 |
-
:param num_fewshot: int
|
| 90 |
-
Number of examples in few-shot context
|
| 91 |
-
:param batch_size: int or str, optional
|
| 92 |
-
Batch size for model
|
| 93 |
-
:param max_batch_size: int, optional
|
| 94 |
-
Maximal batch size to try with automatic batch size detection
|
| 95 |
-
:param device: str, optional
|
| 96 |
-
PyTorch device (e.g. "cpu" or "cuda:0") for running models
|
| 97 |
-
:param use_cache: str, optional
|
| 98 |
-
A path to a sqlite db file for caching model responses. `None` if not caching.
|
| 99 |
-
:param cache_requests: bool, optional
|
| 100 |
-
Speed up evaluation by caching the building of dataset requests. `None` if not caching.
|
| 101 |
-
:param rewrite_requests_cache: bool, optional
|
| 102 |
-
Rewrites all the request cache if set to `True`. `None` if not desired.
|
| 103 |
-
:param delete_requests_cache: bool, optional
|
| 104 |
-
Deletes all the request cache if set to `True`. `None` if not desired.
|
| 105 |
-
:param limit: int or float, optional
|
| 106 |
-
Limit the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples.
|
| 107 |
-
:param samples: dictionary, optional
|
| 108 |
-
Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
|
| 109 |
-
:param bootstrap_iters:
|
| 110 |
-
Number of iterations for bootstrap statistics, used when calculating stderrs. set to 0 for no stderr calculations to be performed.
|
| 111 |
-
:param check_integrity: bool
|
| 112 |
-
Whether to run the relevant part of the test suite for the tasks
|
| 113 |
-
:param write_out: bool
|
| 114 |
-
If True, write out an example document and model input for checking task integrity
|
| 115 |
-
:param log_samples: bool
|
| 116 |
-
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
|
| 117 |
-
:param system_instruction: str
|
| 118 |
-
System instruction to be applied to the prompt
|
| 119 |
-
:param apply_chat_template: Union[bool, str]
|
| 120 |
-
Specifies whether to apply a chat template to the prompt.
|
| 121 |
-
- If set to True, the default chat template is applied.
|
| 122 |
-
- If set to a string, applies the specified chat template by name.
|
| 123 |
-
Defaults to False (no chat template applied).
|
| 124 |
-
:param fewshot_as_multiturn: bool
|
| 125 |
-
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 126 |
-
:param gen_kwargs: dict or comma-separated string
|
| 127 |
-
Arguments for model generation
|
| 128 |
-
Ignored for all tasks with loglikelihood output_type
|
| 129 |
-
:param verbosity: str
|
| 130 |
-
Verbosity level for logging
|
| 131 |
-
:param predict_only: bool
|
| 132 |
-
If true only model outputs will be generated and returned. Metrics will not be evaluated
|
| 133 |
-
:param random_seed: int
|
| 134 |
-
Random seed for python's random module. If set to None, the seed will not be set.
|
| 135 |
-
:param numpy_random_seed: int
|
| 136 |
-
Random seed for numpy. If set to None, the seed will not be set.
|
| 137 |
-
:param torch_random_seed: int
|
| 138 |
-
Random seed for torch. If set to None, the seed will not be set.
|
| 139 |
-
:param fewshot_random_seed: int
|
| 140 |
-
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
|
| 141 |
-
:param metadata: dict
|
| 142 |
-
Additional metadata to be added to the task manager. Will get passed to the download function of the task.
|
| 143 |
-
|
| 144 |
-
return
|
| 145 |
-
Dictionary of results
|
| 146 |
-
"""
|
| 147 |
-
if verbosity is not None:
|
| 148 |
-
setup_logging(verbosity=verbosity)
|
| 149 |
-
start_date = time.time()
|
| 150 |
-
|
| 151 |
-
if limit is not None and samples is not None:
|
| 152 |
-
raise ValueError(
|
| 153 |
-
"Either 'limit' or 'samples' must be None, but both are not None."
|
| 154 |
-
)
|
| 155 |
-
|
| 156 |
-
if isinstance(model_args, str) and (
|
| 157 |
-
"instruct" in model_args and not apply_chat_template
|
| 158 |
-
):
|
| 159 |
-
eval_logger.warning(
|
| 160 |
-
"Instruct model detected, but chat template not applied. Recommend setting `apply_chat_template` (optionally `fewshot_as_multiturn`)."
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
if delete_requests_cache:
|
| 164 |
-
eval_logger.info("Deleting requests cache...")
|
| 165 |
-
delete_cache()
|
| 166 |
-
|
| 167 |
-
seed_message = []
|
| 168 |
-
if random_seed is not None:
|
| 169 |
-
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
|
| 170 |
-
seed_message.append(f"Setting random seed to {random_seed}")
|
| 171 |
-
random.seed(random_seed)
|
| 172 |
-
|
| 173 |
-
if numpy_random_seed is not None:
|
| 174 |
-
seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
|
| 175 |
-
np.random.seed(numpy_random_seed)
|
| 176 |
-
|
| 177 |
-
if torch_random_seed is not None:
|
| 178 |
-
seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
|
| 179 |
-
torch.manual_seed(torch_random_seed)
|
| 180 |
-
|
| 181 |
-
if fewshot_random_seed is not None:
|
| 182 |
-
seed_message.append(f"Setting fewshot manual seed to {fewshot_random_seed}")
|
| 183 |
-
|
| 184 |
-
if seed_message:
|
| 185 |
-
eval_logger.info(" | ".join(seed_message))
|
| 186 |
-
|
| 187 |
-
if tasks is None:
|
| 188 |
-
tasks = []
|
| 189 |
-
if len(tasks) == 0:
|
| 190 |
-
raise ValueError(
|
| 191 |
-
"No tasks specified, or no tasks found. Please verify the task names."
|
| 192 |
-
)
|
| 193 |
-
|
| 194 |
-
if gen_kwargs is not None:
|
| 195 |
-
if isinstance(gen_kwargs, str):
|
| 196 |
-
gen_kwargs = simple_parse_args_string(gen_kwargs)
|
| 197 |
-
eval_logger.warning(
|
| 198 |
-
f"generation_kwargs: {gen_kwargs} specified through cli, these settings will update set parameters in yaml tasks. "
|
| 199 |
-
"Ensure 'do_sample=True' for non-greedy decoding!"
|
| 200 |
-
)
|
| 201 |
-
if not gen_kwargs:
|
| 202 |
-
gen_kwargs = None
|
| 203 |
-
|
| 204 |
-
if isinstance(model, str):
|
| 205 |
-
if model_args is None:
|
| 206 |
-
eval_logger.warning("model_args not specified. Using defaults.")
|
| 207 |
-
model_args = ""
|
| 208 |
-
|
| 209 |
-
if isinstance(model_args, dict):
|
| 210 |
-
eval_logger.info(
|
| 211 |
-
f"Initializing {model} model, with arguments: {model_args}"
|
| 212 |
-
)
|
| 213 |
-
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
|
| 214 |
-
model_args,
|
| 215 |
-
{
|
| 216 |
-
"batch_size": batch_size,
|
| 217 |
-
"max_batch_size": max_batch_size,
|
| 218 |
-
"device": device,
|
| 219 |
-
},
|
| 220 |
-
)
|
| 221 |
-
|
| 222 |
-
else:
|
| 223 |
-
eval_logger.info(
|
| 224 |
-
f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
|
| 225 |
-
)
|
| 226 |
-
lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
|
| 227 |
-
model_args,
|
| 228 |
-
{
|
| 229 |
-
"batch_size": batch_size,
|
| 230 |
-
"max_batch_size": max_batch_size,
|
| 231 |
-
"device": device,
|
| 232 |
-
},
|
| 233 |
-
)
|
| 234 |
-
else:
|
| 235 |
-
if not isinstance(model, lm_eval.api.model.LM):
|
| 236 |
-
raise TypeError(
|
| 237 |
-
f"The value of `model` passed to simple_evaluate() was of type {type(model)}, but is required to be a subclass of lm_eval.api.model.LM . This may be because you are passing an initialized Hugging Face PreTrainedModel without having wrapped it in `lm_eval.models.huggingface.HFLM(pretrained=my_model)` first."
|
| 238 |
-
)
|
| 239 |
-
eval_logger.info("Using pre-initialized model")
|
| 240 |
-
lm = model
|
| 241 |
-
|
| 242 |
-
if use_cache is not None:
|
| 243 |
-
eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
|
| 244 |
-
lm = lm_eval.api.model.CachingLM(
|
| 245 |
-
lm,
|
| 246 |
-
use_cache
|
| 247 |
-
# each rank receives a different cache db.
|
| 248 |
-
# necessary to avoid multiple writes to cache at once
|
| 249 |
-
+ "_rank"
|
| 250 |
-
+ str(lm.rank)
|
| 251 |
-
+ ".db",
|
| 252 |
-
)
|
| 253 |
-
|
| 254 |
-
if task_manager is None:
|
| 255 |
-
metadata = (
|
| 256 |
-
simple_parse_args_string(model_args)
|
| 257 |
-
if isinstance(model_args, str)
|
| 258 |
-
else model_args
|
| 259 |
-
if isinstance(model_args, dict)
|
| 260 |
-
else {}
|
| 261 |
-
) | (metadata or {})
|
| 262 |
-
task_manager = TaskManager(metadata=metadata)
|
| 263 |
-
|
| 264 |
-
task_dict = get_task_dict(
|
| 265 |
-
tasks,
|
| 266 |
-
task_manager,
|
| 267 |
-
)
|
| 268 |
-
|
| 269 |
-
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
|
| 270 |
-
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
|
| 271 |
-
def _adjust_config(task_dict):
|
| 272 |
-
adjusted_task_dict = {}
|
| 273 |
-
for task_name, task_obj in task_dict.items():
|
| 274 |
-
if isinstance(task_obj, dict):
|
| 275 |
-
adjusted_task_dict = {
|
| 276 |
-
**adjusted_task_dict,
|
| 277 |
-
**{task_name: _adjust_config(task_obj)},
|
| 278 |
-
}
|
| 279 |
-
|
| 280 |
-
else:
|
| 281 |
-
if task_obj.get_config("output_type") == "generate_until":
|
| 282 |
-
if gen_kwargs is not None:
|
| 283 |
-
task_obj.set_config(
|
| 284 |
-
key="generation_kwargs", value=gen_kwargs, update=True
|
| 285 |
-
)
|
| 286 |
-
eval_logger.info(
|
| 287 |
-
f"{task_obj.config.task}: Using gen_kwargs: {task_obj.config.generation_kwargs}"
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
-
if predict_only:
|
| 291 |
-
eval_logger.info(
|
| 292 |
-
f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
|
| 293 |
-
)
|
| 294 |
-
# we have to change the class properties post-hoc. This is pretty hacky.
|
| 295 |
-
task_obj.override_metric(metric_name="bypass")
|
| 296 |
-
|
| 297 |
-
# override tasks' fewshot values to the provided num_fewshot arg value
|
| 298 |
-
# except if tasks have it set to 0 manually in their configs--then we should never overwrite that
|
| 299 |
-
if num_fewshot is not None:
|
| 300 |
-
if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
|
| 301 |
-
eval_logger.info(
|
| 302 |
-
f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
|
| 303 |
-
)
|
| 304 |
-
else:
|
| 305 |
-
eval_logger.warning(
|
| 306 |
-
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
|
| 307 |
-
)
|
| 308 |
-
task_obj.set_config(key="num_fewshot", value=num_fewshot)
|
| 309 |
-
else:
|
| 310 |
-
# if num_fewshot not provided, and the task does not define a default one, default to 0
|
| 311 |
-
if (
|
| 312 |
-
default_num_fewshot := task_obj.get_config("num_fewshot")
|
| 313 |
-
) is None:
|
| 314 |
-
task_obj.set_config(key="num_fewshot", value=0)
|
| 315 |
-
# fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
|
| 316 |
-
task_obj.set_fewshot_seed(seed=fewshot_random_seed)
|
| 317 |
-
|
| 318 |
-
adjusted_task_dict[task_name] = task_obj
|
| 319 |
-
|
| 320 |
-
return adjusted_task_dict
|
| 321 |
-
|
| 322 |
-
task_dict = _adjust_config(task_dict)
|
| 323 |
-
|
| 324 |
-
if check_integrity:
|
| 325 |
-
run_task_tests(task_list=tasks)
|
| 326 |
-
|
| 327 |
-
if evaluation_tracker is not None:
|
| 328 |
-
evaluation_tracker.general_config_tracker.log_experiment_args(
|
| 329 |
-
model_source=model,
|
| 330 |
-
model_args=model_args,
|
| 331 |
-
system_instruction=system_instruction,
|
| 332 |
-
chat_template=lm.chat_template(apply_chat_template)
|
| 333 |
-
if apply_chat_template
|
| 334 |
-
else None,
|
| 335 |
-
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 336 |
-
)
|
| 337 |
-
|
| 338 |
-
results = evaluate(
|
| 339 |
-
lm=lm,
|
| 340 |
-
task_dict=task_dict,
|
| 341 |
-
limit=limit,
|
| 342 |
-
samples=samples,
|
| 343 |
-
cache_requests=cache_requests,
|
| 344 |
-
rewrite_requests_cache=rewrite_requests_cache,
|
| 345 |
-
bootstrap_iters=bootstrap_iters,
|
| 346 |
-
write_out=write_out,
|
| 347 |
-
log_samples=True if predict_only else log_samples,
|
| 348 |
-
system_instruction=system_instruction,
|
| 349 |
-
apply_chat_template=apply_chat_template,
|
| 350 |
-
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 351 |
-
verbosity=verbosity,
|
| 352 |
-
confirm_run_unsafe_code=confirm_run_unsafe_code,
|
| 353 |
-
)
|
| 354 |
-
if verbosity is not None:
|
| 355 |
-
setup_logging(verbosity=verbosity)
|
| 356 |
-
|
| 357 |
-
if lm.rank == 0:
|
| 358 |
-
if isinstance(model, str):
|
| 359 |
-
model_name = model
|
| 360 |
-
elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
|
| 361 |
-
model_name = model.config._name_or_path
|
| 362 |
-
else:
|
| 363 |
-
model_name = type(model).__name__
|
| 364 |
-
|
| 365 |
-
# add info about the model and few shot config
|
| 366 |
-
results["config"] = {
|
| 367 |
-
"model": model_name,
|
| 368 |
-
"model_args": model_args,
|
| 369 |
-
}
|
| 370 |
-
# add more detailed model info if available
|
| 371 |
-
if isinstance(lm, lm_eval.models.huggingface.HFLM):
|
| 372 |
-
results["config"].update(lm.get_model_info())
|
| 373 |
-
# add info about execution
|
| 374 |
-
results["config"].update(
|
| 375 |
-
{
|
| 376 |
-
"batch_size": batch_size,
|
| 377 |
-
"batch_sizes": (
|
| 378 |
-
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
|
| 379 |
-
),
|
| 380 |
-
"device": device,
|
| 381 |
-
"use_cache": use_cache,
|
| 382 |
-
"limit": limit,
|
| 383 |
-
"bootstrap_iters": bootstrap_iters,
|
| 384 |
-
"gen_kwargs": gen_kwargs,
|
| 385 |
-
"random_seed": random_seed,
|
| 386 |
-
"numpy_seed": numpy_random_seed,
|
| 387 |
-
"torch_seed": torch_random_seed,
|
| 388 |
-
"fewshot_seed": fewshot_random_seed,
|
| 389 |
-
}
|
| 390 |
-
)
|
| 391 |
-
results["git_hash"] = get_git_commit_hash()
|
| 392 |
-
results["date"] = start_date
|
| 393 |
-
add_env_info(results) # additional environment info to results
|
| 394 |
-
add_tokenizer_info(results, lm) # additional info about tokenizer
|
| 395 |
-
return results
|
| 396 |
-
else:
|
| 397 |
-
return None
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
@positional_deprecated
|
| 401 |
-
def evaluate(
|
| 402 |
-
lm: "LM",
|
| 403 |
-
task_dict,
|
| 404 |
-
limit: Optional[int] = None,
|
| 405 |
-
samples: Optional[dict] = None,
|
| 406 |
-
cache_requests: bool = False,
|
| 407 |
-
rewrite_requests_cache: bool = False,
|
| 408 |
-
bootstrap_iters: Optional[int] = 100000,
|
| 409 |
-
write_out: bool = False,
|
| 410 |
-
log_samples: bool = True,
|
| 411 |
-
system_instruction: Optional[str] = None,
|
| 412 |
-
apply_chat_template: Union[bool, str] = False,
|
| 413 |
-
fewshot_as_multiturn: bool = False,
|
| 414 |
-
verbosity: str = "INFO",
|
| 415 |
-
confirm_run_unsafe_code: bool = False,
|
| 416 |
-
):
|
| 417 |
-
"""Instantiate and evaluate a model on a list of tasks.
|
| 418 |
-
|
| 419 |
-
:param lm: obj
|
| 420 |
-
Language Model
|
| 421 |
-
:param task_dict: dict[str, Task]
|
| 422 |
-
Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
|
| 423 |
-
:param limit: int, optional
|
| 424 |
-
Limit the number of examples per task (only use this for testing)
|
| 425 |
-
:param samples: dictionary, optional
|
| 426 |
-
Dictionary indicating which examples should be tested in each task, e.g., {"mmlu_astronomy":[0,3,6],"mmlu_anatomy":[1,4,7,10]}.
|
| 427 |
-
:param cache_requests: bool, optional
|
| 428 |
-
Speed up evaluation by caching the building of dataset requests.
|
| 429 |
-
:param rewrite_requests_cache: bool, optional
|
| 430 |
-
Rewrites all the request cache if set to `True`.
|
| 431 |
-
:param bootstrap_iters:
|
| 432 |
-
Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 for skipping all stderr calculations.
|
| 433 |
-
:param write_out: bool
|
| 434 |
-
If True, write out an example document and model input for checking task integrity
|
| 435 |
-
:param log_samples: bool
|
| 436 |
-
If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
|
| 437 |
-
:param system_instruction: str
|
| 438 |
-
System instruction to be applied to the prompt
|
| 439 |
-
:param apply_chat_template: Union[bool, str]
|
| 440 |
-
Specifies whether to apply a chat template to the prompt.
|
| 441 |
-
- If set to True, the default chat template is applied.
|
| 442 |
-
- If set to a string, applies the specified chat template by name.
|
| 443 |
-
Defaults to False (no chat template applied).
|
| 444 |
-
:param fewshot_as_multiturn: bool
|
| 445 |
-
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
|
| 446 |
-
:param verbosity: str
|
| 447 |
-
Verbosity level for logging
|
| 448 |
-
:param confirm_run_unsafe_code: bool
|
| 449 |
-
Whether to confirm running tasks marked as unsafe.
|
| 450 |
-
:return
|
| 451 |
-
Dictionary of results
|
| 452 |
-
"""
|
| 453 |
-
|
| 454 |
-
if limit is not None and samples is not None:
|
| 455 |
-
raise ValueError(
|
| 456 |
-
"Either 'limit' or 'samples' must be None, but both are not None."
|
| 457 |
-
)
|
| 458 |
-
if samples is not None:
|
| 459 |
-
eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
|
| 460 |
-
if apply_chat_template:
|
| 461 |
-
eval_logger.warning(
|
| 462 |
-
"Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
|
| 463 |
-
)
|
| 464 |
-
# tracks all Instances/requests a model must generate output on.
|
| 465 |
-
requests = defaultdict(list)
|
| 466 |
-
# stores the amount to pad out reqs per req. type so that
|
| 467 |
-
# number of fwd passes per distributed rank is equal
|
| 468 |
-
padding_requests = defaultdict(int)
|
| 469 |
-
|
| 470 |
-
# get lists of group hierarchy and each type of request
|
| 471 |
-
eval_tasks = get_task_list(task_dict)
|
| 472 |
-
if not log_samples:
|
| 473 |
-
if not all(
|
| 474 |
-
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
|
| 475 |
-
for task_output in eval_tasks
|
| 476 |
-
):
|
| 477 |
-
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
|
| 478 |
-
|
| 479 |
-
# validation checks:
|
| 480 |
-
# 1.are we running multimodal task <-> non-multimodal model class, or vice-versa.
|
| 481 |
-
# 2.are we running code that is marked as unsafe.
|
| 482 |
-
incompatible_tasks = []
|
| 483 |
-
for task_output in eval_tasks:
|
| 484 |
-
task: Task = task_output.task
|
| 485 |
-
|
| 486 |
-
if getattr(task, "MULTIMODAL", False) and not getattr(lm, "MULTIMODAL", False):
|
| 487 |
-
incompatible_tasks.append(task_output.task_name)
|
| 488 |
-
elif getattr(task, "UNSAFE_CODE", False) and not confirm_run_unsafe_code:
|
| 489 |
-
raise ValueError(
|
| 490 |
-
f"Attempted to run task: {task_output.task_name} which is marked as unsafe. Set confirm_run_unsafe_code=True to run this task."
|
| 491 |
-
)
|
| 492 |
-
if len(incompatible_tasks) > 0:
|
| 493 |
-
if not getattr(lm, "MULTIMODAL", False):
|
| 494 |
-
raise ValueError(
|
| 495 |
-
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
|
| 496 |
-
)
|
| 497 |
-
# end validation check
|
| 498 |
-
|
| 499 |
-
# Cache the limit arg.
|
| 500 |
-
limit_arg = limit
|
| 501 |
-
limits = []
|
| 502 |
-
for task_output in eval_tasks:
|
| 503 |
-
task: Task = task_output.task
|
| 504 |
-
|
| 505 |
-
limit = get_sample_size(task, limit_arg)
|
| 506 |
-
limits.append(limit)
|
| 507 |
-
task.build_all_requests(
|
| 508 |
-
limit=limit,
|
| 509 |
-
samples=samples.get(task_output.task_name, None)
|
| 510 |
-
if samples is not None
|
| 511 |
-
else samples,
|
| 512 |
-
rank=lm.rank,
|
| 513 |
-
world_size=lm.world_size,
|
| 514 |
-
cache_requests=cache_requests,
|
| 515 |
-
rewrite_requests_cache=rewrite_requests_cache,
|
| 516 |
-
system_instruction=system_instruction,
|
| 517 |
-
apply_chat_template=bool(apply_chat_template),
|
| 518 |
-
fewshot_as_multiturn=fewshot_as_multiturn,
|
| 519 |
-
chat_template=getattr(lm, "apply_chat_template")
|
| 520 |
-
if apply_chat_template
|
| 521 |
-
else None,
|
| 522 |
-
tokenizer_name=getattr(lm, "tokenizer_name", "")
|
| 523 |
-
if apply_chat_template
|
| 524 |
-
else "",
|
| 525 |
-
)
|
| 526 |
-
eval_logger.debug(
|
| 527 |
-
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
|
| 528 |
-
)
|
| 529 |
-
if write_out:
|
| 530 |
-
print_writeout(task)
|
| 531 |
-
# aggregate Instances by LM method requested to get output.
|
| 532 |
-
for instance in task.instances:
|
| 533 |
-
reqtype = instance.request_type
|
| 534 |
-
requests[reqtype].append(instance)
|
| 535 |
-
|
| 536 |
-
if lm.world_size > 1:
|
| 537 |
-
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
|
| 538 |
-
gathered_item = (
|
| 539 |
-
lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
|
| 540 |
-
)
|
| 541 |
-
# "multiple_choice" task types dispatch (several) "loglikelihood" request types
|
| 542 |
-
reqtype = (
|
| 543 |
-
"loglikelihood"
|
| 544 |
-
if task.OUTPUT_TYPE == "multiple_choice"
|
| 545 |
-
else task.OUTPUT_TYPE
|
| 546 |
-
)
|
| 547 |
-
# compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
|
| 548 |
-
numpad = max(gathered_item) - gathered_item[lm.rank]
|
| 549 |
-
# todo: may not account for padding in cases like SquadV2 which has multiple req types
|
| 550 |
-
padding_requests[reqtype] += numpad
|
| 551 |
-
|
| 552 |
-
### Run LM on inputs, get all outputs ###
|
| 553 |
-
# execute each type of request
|
| 554 |
-
for reqtype, reqs in requests.items():
|
| 555 |
-
eval_logger.info(f"Running {reqtype} requests")
|
| 556 |
-
# create `K` copies of each request `req` based off `K = req.repeats`
|
| 557 |
-
cloned_reqs = []
|
| 558 |
-
for req in reqs:
|
| 559 |
-
cloned_reqs.extend([req] * req.repeats)
|
| 560 |
-
|
| 561 |
-
if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
|
| 562 |
-
for _ in range(padding_requests[reqtype]):
|
| 563 |
-
cloned_reqs.extend([req] * req.repeats)
|
| 564 |
-
|
| 565 |
-
# run requests through model
|
| 566 |
-
resps = getattr(lm, reqtype)(cloned_reqs)
|
| 567 |
-
|
| 568 |
-
# put responses from model into a list of length K for each request.
|
| 569 |
-
for x, req in zip(resps, cloned_reqs):
|
| 570 |
-
req.resps.append(x)
|
| 571 |
-
|
| 572 |
-
if lm.world_size > 1:
|
| 573 |
-
lm.accelerator.wait_for_everyone()
|
| 574 |
-
|
| 575 |
-
RANK = lm.rank
|
| 576 |
-
WORLD_SIZE = lm.world_size
|
| 577 |
-
### Postprocess outputs ###
|
| 578 |
-
# TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
|
| 579 |
-
for task_output, limit in zip(eval_tasks, limits):
|
| 580 |
-
task = task_output.task
|
| 581 |
-
task.apply_filters()
|
| 582 |
-
|
| 583 |
-
### Collect values of metrics on all datapoints ###
|
| 584 |
-
# # unpack results and sort back in order and return control to Task
|
| 585 |
-
# TODO: make it possible to use a different metric per filter
|
| 586 |
-
# Pre-process task.instances to group by doc_id
|
| 587 |
-
instances_by_doc_id = defaultdict(list)
|
| 588 |
-
for instance in task.instances:
|
| 589 |
-
instances_by_doc_id[instance.doc_id].append(instance)
|
| 590 |
-
# Sort instances within each group
|
| 591 |
-
for instances in instances_by_doc_id.values():
|
| 592 |
-
instances.sort(key=lambda x: x.idx)
|
| 593 |
-
# iterate over different filters used
|
| 594 |
-
for filter_key in task.instances[0].filtered_resps.keys():
|
| 595 |
-
indices = (
|
| 596 |
-
samples.get(task_output.task_name, None)
|
| 597 |
-
if samples is not None
|
| 598 |
-
else None
|
| 599 |
-
)
|
| 600 |
-
doc_iterator = task.doc_iterator(
|
| 601 |
-
rank=RANK,
|
| 602 |
-
limit=limit,
|
| 603 |
-
world_size=WORLD_SIZE,
|
| 604 |
-
samples=indices,
|
| 605 |
-
)
|
| 606 |
-
for doc_id, doc in doc_iterator:
|
| 607 |
-
if indices:
|
| 608 |
-
doc_id_true = indices[doc_id]
|
| 609 |
-
else:
|
| 610 |
-
doc_id_true = doc_id
|
| 611 |
-
requests = instances_by_doc_id[doc_id]
|
| 612 |
-
metrics = task.process_results(
|
| 613 |
-
doc, [req.filtered_resps[filter_key] for req in requests]
|
| 614 |
-
)
|
| 615 |
-
if log_samples:
|
| 616 |
-
target = task.doc_to_target(doc)
|
| 617 |
-
example = {
|
| 618 |
-
"doc_id": doc_id_true,
|
| 619 |
-
"doc": doc,
|
| 620 |
-
"target": target,
|
| 621 |
-
"arguments": [req.args for req in requests],
|
| 622 |
-
"resps": [req.resps for req in requests],
|
| 623 |
-
"filtered_resps": [
|
| 624 |
-
req.filtered_resps[filter_key] for req in requests
|
| 625 |
-
],
|
| 626 |
-
"filter": filter_key,
|
| 627 |
-
"metrics": list(metrics.keys()),
|
| 628 |
-
"doc_hash": hash_string(
|
| 629 |
-
json.dumps(
|
| 630 |
-
requests[0].doc,
|
| 631 |
-
indent=2,
|
| 632 |
-
default=handle_non_serializable,
|
| 633 |
-
ensure_ascii=False,
|
| 634 |
-
)
|
| 635 |
-
),
|
| 636 |
-
"prompt_hash": hash_string(requests[0].arguments[0]),
|
| 637 |
-
"target_hash": hash_string(str(target)),
|
| 638 |
-
}
|
| 639 |
-
example.update(metrics)
|
| 640 |
-
task_output.logged_samples.append(example)
|
| 641 |
-
for metric, value in metrics.items():
|
| 642 |
-
task_output.sample_metrics[(metric, filter_key)].append(value)
|
| 643 |
-
|
| 644 |
-
if WORLD_SIZE > 1:
|
| 645 |
-
# if multigpu, then gather data across all ranks to rank 0
|
| 646 |
-
# first gather logged samples across all ranks
|
| 647 |
-
for task_output in eval_tasks:
|
| 648 |
-
if log_samples:
|
| 649 |
-
# for task_name, task_samples in list(samples.items()):
|
| 650 |
-
full_samples = [None] * WORLD_SIZE if RANK == 0 else None
|
| 651 |
-
torch.distributed.gather_object(
|
| 652 |
-
obj=task_output.logged_samples,
|
| 653 |
-
object_gather_list=full_samples,
|
| 654 |
-
dst=0,
|
| 655 |
-
)
|
| 656 |
-
|
| 657 |
-
if RANK == 0:
|
| 658 |
-
task_output.logged_samples = list(
|
| 659 |
-
itertools.chain.from_iterable(full_samples)
|
| 660 |
-
)
|
| 661 |
-
|
| 662 |
-
# then collect metrics across all ranks
|
| 663 |
-
for metrics in task_output.sample_metrics:
|
| 664 |
-
metric_list = [None] * WORLD_SIZE if RANK == 0 else None
|
| 665 |
-
torch.distributed.gather_object(
|
| 666 |
-
obj=task_output.sample_metrics[metrics],
|
| 667 |
-
object_gather_list=metric_list,
|
| 668 |
-
dst=0,
|
| 669 |
-
)
|
| 670 |
-
if RANK == 0:
|
| 671 |
-
task_output.sample_metrics[metrics] = list(
|
| 672 |
-
itertools.chain.from_iterable(metric_list)
|
| 673 |
-
)
|
| 674 |
-
|
| 675 |
-
if RANK == 0:
|
| 676 |
-
### Aggregate results over all datapoints ###
|
| 677 |
-
# aggregate results ; run bootstrap CIs
|
| 678 |
-
for task_output in eval_tasks:
|
| 679 |
-
task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
|
| 680 |
-
(
|
| 681 |
-
results,
|
| 682 |
-
samples,
|
| 683 |
-
configs,
|
| 684 |
-
versions,
|
| 685 |
-
num_fewshot,
|
| 686 |
-
higher_is_better,
|
| 687 |
-
) = consolidate_results(eval_tasks)
|
| 688 |
-
|
| 689 |
-
### Calculate group metrics ###
|
| 690 |
-
if bool(results):
|
| 691 |
-
results, versions, show_group_table, *_ = consolidate_group_results(
|
| 692 |
-
results, versions, task_dict
|
| 693 |
-
)
|
| 694 |
-
|
| 695 |
-
results_agg, group_agg = prepare_print_tasks(task_dict, results)
|
| 696 |
-
subtask_list = get_subtask_list(task_dict)
|
| 697 |
-
|
| 698 |
-
# collect all higher_is_better values for metrics
|
| 699 |
-
# in the group's subtasks.
|
| 700 |
-
# TODO: clean this up ; unify with the below metric_list loop?
|
| 701 |
-
_higher_is_better = {}
|
| 702 |
-
for group, task_list in subtask_list.items():
|
| 703 |
-
if (
|
| 704 |
-
len(task_list) != 0
|
| 705 |
-
): # subtask list will list "task_name": [] for solo tasks
|
| 706 |
-
for task in task_list:
|
| 707 |
-
for m, h in higher_is_better[task].items():
|
| 708 |
-
if m not in _higher_is_better.keys():
|
| 709 |
-
_higher_is_better[m] = h
|
| 710 |
-
|
| 711 |
-
if (
|
| 712 |
-
m in _higher_is_better
|
| 713 |
-
and _higher_is_better[m] is not None
|
| 714 |
-
and _higher_is_better[m] != h
|
| 715 |
-
):
|
| 716 |
-
eval_logger.warning(
|
| 717 |
-
f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
|
| 718 |
-
)
|
| 719 |
-
_higher_is_better[m] = None
|
| 720 |
-
higher_is_better[group] = _higher_is_better
|
| 721 |
-
|
| 722 |
-
results_dict = {
|
| 723 |
-
"results": dict(results_agg.items()),
|
| 724 |
-
**(
|
| 725 |
-
{"groups": dict(group_agg.items())}
|
| 726 |
-
if (bool(group_agg) & show_group_table)
|
| 727 |
-
else {}
|
| 728 |
-
),
|
| 729 |
-
"group_subtasks": dict(reversed(subtask_list.items())),
|
| 730 |
-
"configs": dict(sorted(configs.items())),
|
| 731 |
-
"versions": dict(sorted(versions.items())),
|
| 732 |
-
"n-shot": dict(sorted(num_fewshot.items())),
|
| 733 |
-
"higher_is_better": dict(sorted(higher_is_better.items())),
|
| 734 |
-
"n-samples": {
|
| 735 |
-
task_output.task_name: {
|
| 736 |
-
"original": len(task_output.task.eval_docs),
|
| 737 |
-
"effective": min(
|
| 738 |
-
limit if limit else len(task_output.task.eval_docs),
|
| 739 |
-
len(task_output.task.eval_docs),
|
| 740 |
-
),
|
| 741 |
-
}
|
| 742 |
-
for task_output, limit in zip(eval_tasks, limits)
|
| 743 |
-
},
|
| 744 |
-
}
|
| 745 |
-
if log_samples:
|
| 746 |
-
results_dict["samples"] = dict(samples)
|
| 747 |
-
|
| 748 |
-
return results_dict
|
| 749 |
-
|
| 750 |
-
else:
|
| 751 |
-
return None
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
def request_caching_arg_to_dict(cache_requests: str) -> dict:
|
| 755 |
-
request_caching_args = {
|
| 756 |
-
"cache_requests": cache_requests in {"true", "refresh"},
|
| 757 |
-
"rewrite_requests_cache": cache_requests == "refresh",
|
| 758 |
-
"delete_requests_cache": cache_requests == "delete",
|
| 759 |
-
}
|
| 760 |
-
|
| 761 |
-
return request_caching_args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/evaluator_utils.py
DELETED
|
@@ -1,554 +0,0 @@
|
|
| 1 |
-
import collections
|
| 2 |
-
import logging
|
| 3 |
-
import math
|
| 4 |
-
import pathlib
|
| 5 |
-
import sys
|
| 6 |
-
from typing import List, Optional, Tuple, Union
|
| 7 |
-
|
| 8 |
-
from lm_eval.api.group import ConfigurableGroup
|
| 9 |
-
from lm_eval.api.metrics import (
|
| 10 |
-
aggregate_subtask_metrics,
|
| 11 |
-
mean,
|
| 12 |
-
pooled_sample_stderr,
|
| 13 |
-
stderr_for_metric,
|
| 14 |
-
)
|
| 15 |
-
from lm_eval.api.task import Task
|
| 16 |
-
from lm_eval.utils import positional_deprecated
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
eval_logger = logging.getLogger(__name__)
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
class TaskOutput:
|
| 23 |
-
"""
|
| 24 |
-
Wrapper class for Task outputs.It contains various attributes and methods to manage and calculate metrics for the task.
|
| 25 |
-
|
| 26 |
-
Attributes:
|
| 27 |
-
task (object): The task object.
|
| 28 |
-
task_name (str): The name of the task.
|
| 29 |
-
task_config (dict): The configuration of the task.
|
| 30 |
-
version (str): The version of the task.
|
| 31 |
-
group_name (str): The name of the task group.
|
| 32 |
-
n_shot (int): The number of shots for the task.
|
| 33 |
-
task_alias (str): The alias of the task.
|
| 34 |
-
group_alias (str): The alias of the task group.
|
| 35 |
-
is_group (bool): Indicates if the task is a group.
|
| 36 |
-
logged_samples (list): The list of logged samples.
|
| 37 |
-
sample_len (int): The length of the samples.
|
| 38 |
-
sample_metrics (defaultdict): The dictionary of samples' metrics.
|
| 39 |
-
agg_metrics (defaultdict): The dictionary of aggregate metrics.
|
| 40 |
-
|
| 41 |
-
Methods:
|
| 42 |
-
from_taskdict(cls, task_name: str, task):
|
| 43 |
-
Creates a TaskOutput instance from a task dictionary.
|
| 44 |
-
|
| 45 |
-
calculate_aggregate_metric(bootstrap_iters=100000) -> None:
|
| 46 |
-
Calculates the aggregate metrics for the task.
|
| 47 |
-
"""
|
| 48 |
-
|
| 49 |
-
def __init__(
|
| 50 |
-
self,
|
| 51 |
-
task=None,
|
| 52 |
-
task_name=None,
|
| 53 |
-
task_config=None,
|
| 54 |
-
version=None,
|
| 55 |
-
group_name=None,
|
| 56 |
-
n_shot=None,
|
| 57 |
-
task_alias=None,
|
| 58 |
-
group_alias=None,
|
| 59 |
-
is_group=None,
|
| 60 |
-
):
|
| 61 |
-
self.task = task
|
| 62 |
-
self.task_config = task_config
|
| 63 |
-
self.task_name = task_name
|
| 64 |
-
self.group_name = group_name
|
| 65 |
-
self.version = version
|
| 66 |
-
self.n_shot = n_shot
|
| 67 |
-
self.task_alias = task_alias
|
| 68 |
-
self.group_alias = group_alias
|
| 69 |
-
self.is_group = is_group
|
| 70 |
-
self.logged_samples = []
|
| 71 |
-
self.sample_len = None
|
| 72 |
-
self.sample_metrics = collections.defaultdict(list)
|
| 73 |
-
self.agg_metrics = collections.defaultdict(list)
|
| 74 |
-
|
| 75 |
-
@classmethod
|
| 76 |
-
def from_taskdict(cls, task_name: str, task):
|
| 77 |
-
if isinstance(task, tuple):
|
| 78 |
-
group_name, task = task
|
| 79 |
-
else:
|
| 80 |
-
group_name = None
|
| 81 |
-
if not task:
|
| 82 |
-
# these gets filtered out in get_task_list
|
| 83 |
-
# once they are added to group hierarchy
|
| 84 |
-
is_group = True
|
| 85 |
-
return cls(
|
| 86 |
-
task=task, task_name=task_name, is_group=is_group, group_name=group_name
|
| 87 |
-
)
|
| 88 |
-
version = task.VERSION
|
| 89 |
-
task_config = dict(task.dump_config())
|
| 90 |
-
if (n_shot := task_config.get("num_fewshot")) == 0:
|
| 91 |
-
n_shot = task_config.get("metadata", {}).get("num_fewshot", 0)
|
| 92 |
-
task_alias = task_config.get("alias")
|
| 93 |
-
group_alias = task_config.get("group_alias")
|
| 94 |
-
return cls(
|
| 95 |
-
task=task,
|
| 96 |
-
task_name=task_name,
|
| 97 |
-
task_config=task_config,
|
| 98 |
-
group_name=group_name,
|
| 99 |
-
version=version,
|
| 100 |
-
n_shot=n_shot,
|
| 101 |
-
task_alias=task_alias,
|
| 102 |
-
group_alias=group_alias,
|
| 103 |
-
)
|
| 104 |
-
|
| 105 |
-
def calculate_aggregate_metric(self, bootstrap_iters=100000) -> None:
|
| 106 |
-
for (metric, filter_key), items in self.sample_metrics.items():
|
| 107 |
-
try:
|
| 108 |
-
agg_fn = self.task.aggregation()[metric]
|
| 109 |
-
except KeyError:
|
| 110 |
-
# This is when process results output an arbitrary metric
|
| 111 |
-
# TODO: Handle this better and allow other aggregate functions other than mean.
|
| 112 |
-
agg_fn = mean
|
| 113 |
-
metric_key = f"{metric},{filter_key}"
|
| 114 |
-
self.agg_metrics[metric_key] = agg_fn(items)
|
| 115 |
-
self.sample_len = len(items) # TODO: same sample size for each metric?
|
| 116 |
-
if isinstance(bootstrap_iters, int):
|
| 117 |
-
stderr_fn = stderr_for_metric(
|
| 118 |
-
metric=agg_fn,
|
| 119 |
-
bootstrap_iters=min(bootstrap_iters, 100)
|
| 120 |
-
if metric in ["bleu", "chrf", "ter"]
|
| 121 |
-
else bootstrap_iters,
|
| 122 |
-
)
|
| 123 |
-
self.agg_metrics[f"{metric}_stderr,{filter_key}"] = (
|
| 124 |
-
stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
|
| 125 |
-
)
|
| 126 |
-
else:
|
| 127 |
-
raise ValueError(
|
| 128 |
-
f"Received bootstrap_iters '{bootstrap_iters}' but expected an integer. Set to 0 to turn off stderr calculations."
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
def __repr__(self):
|
| 132 |
-
return (
|
| 133 |
-
f"TaskOutput(task_name={self.task_name}, "
|
| 134 |
-
f"group_name={self.group_name}, "
|
| 135 |
-
f"version={self.version}, "
|
| 136 |
-
f"n_shot={self.n_shot}, "
|
| 137 |
-
f"task_alias={self.task_alias}, "
|
| 138 |
-
f"group_alias={self.group_alias})"
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def get_task_list(task_dict: dict) -> List[TaskOutput]:
|
| 143 |
-
outputs = []
|
| 144 |
-
for task_name, task_obj in task_dict.items():
|
| 145 |
-
if isinstance(task_obj, dict):
|
| 146 |
-
_outputs = get_task_list(task_obj)
|
| 147 |
-
outputs.extend(_outputs)
|
| 148 |
-
else:
|
| 149 |
-
task_output = TaskOutput.from_taskdict(task_name, task_obj)
|
| 150 |
-
outputs.append(task_output)
|
| 151 |
-
|
| 152 |
-
return outputs
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def get_subtask_list(task_dict, task_root=None, depth=0):
|
| 156 |
-
subtask_list = {}
|
| 157 |
-
for group_obj, task_obj in task_dict.items():
|
| 158 |
-
if isinstance(group_obj, ConfigurableGroup):
|
| 159 |
-
# group_name = group_obj.group_name
|
| 160 |
-
group_name = group_obj.group_name
|
| 161 |
-
else:
|
| 162 |
-
group_name = group_obj
|
| 163 |
-
if isinstance(task_obj, dict):
|
| 164 |
-
_subtask_list = get_subtask_list(
|
| 165 |
-
task_obj, task_root=group_name, depth=depth + 1
|
| 166 |
-
)
|
| 167 |
-
if task_root:
|
| 168 |
-
subtask_list.setdefault((task_root, depth), []).extend(
|
| 169 |
-
[
|
| 170 |
-
_task
|
| 171 |
-
for (_task, _depth) in _subtask_list.keys()
|
| 172 |
-
if (_depth - 1) == depth
|
| 173 |
-
]
|
| 174 |
-
)
|
| 175 |
-
|
| 176 |
-
subtask_list = {**subtask_list, **_subtask_list}
|
| 177 |
-
else:
|
| 178 |
-
if isinstance(task_obj, ConfigurableGroup):
|
| 179 |
-
# group_or_task_name = task_obj.group_name
|
| 180 |
-
group_or_task_name = task_obj.group_name
|
| 181 |
-
elif isinstance(task_obj, Task):
|
| 182 |
-
# group_or_task_name = task_obj.task_name
|
| 183 |
-
group_or_task_name = task_obj.task_name
|
| 184 |
-
|
| 185 |
-
if task_root is None:
|
| 186 |
-
subtask_list.setdefault((group_or_task_name, depth), [])
|
| 187 |
-
else:
|
| 188 |
-
subtask_list.setdefault((task_root, depth), []).append(
|
| 189 |
-
group_or_task_name
|
| 190 |
-
)
|
| 191 |
-
|
| 192 |
-
if depth == 0:
|
| 193 |
-
_subtask_list = {}
|
| 194 |
-
for group_key, task_list in subtask_list.items():
|
| 195 |
-
group_name, depth = group_key
|
| 196 |
-
_subtask_list[group_name] = task_list
|
| 197 |
-
subtask_list = _subtask_list
|
| 198 |
-
|
| 199 |
-
return subtask_list
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
def print_writeout(task) -> None:
|
| 203 |
-
for inst in task.instances:
|
| 204 |
-
# print the prompt for the first few documents
|
| 205 |
-
if inst.doc_id < 1:
|
| 206 |
-
eval_logger.info(
|
| 207 |
-
f"Task: {task}; document {inst.doc_id}; context prompt (starting on next line):\
|
| 208 |
-
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
|
| 209 |
-
)
|
| 210 |
-
eval_logger.info(f"Request: {str(inst)}")
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
def get_sample_size(task, limit: Optional[int]) -> Union[int, None]:
|
| 214 |
-
if limit is not None:
|
| 215 |
-
limit = (
|
| 216 |
-
int(math.ceil(len(task.eval_docs) * limit)) if limit < 1.0 else int(limit)
|
| 217 |
-
)
|
| 218 |
-
return limit
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
def prepare_print_tasks(
|
| 222 |
-
task_dict: dict,
|
| 223 |
-
results: dict,
|
| 224 |
-
task_depth=0,
|
| 225 |
-
group_depth=0,
|
| 226 |
-
) -> Tuple[dict, dict]:
|
| 227 |
-
"""
|
| 228 |
-
@param task_dict: Dictionary representing the group hierarchy of tasks. Each key is a group name and its
|
| 229 |
-
value is a list of task names.
|
| 230 |
-
@param results: Dictionary containing the results of each task. Each key is a
|
| 231 |
-
group name and its value is a dictionary of task results.
|
| 232 |
-
@param task_depth: The indentation level for printing the task
|
| 233 |
-
hierarchy. Default is 0.
|
| 234 |
-
@param group_depth: The indentation level for printing the group
|
| 235 |
-
hierarchy. Default is 0.
|
| 236 |
-
@return: A tuple of two dictionaries: results_agg and groups_agg. results_agg contains
|
| 237 |
-
aggregated results for each task, and groups_agg contains aggregated results for each group.
|
| 238 |
-
|
| 239 |
-
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
|
| 240 |
-
"""
|
| 241 |
-
|
| 242 |
-
def _sort_task_dict(task_dict):
|
| 243 |
-
"""
|
| 244 |
-
Helper utility. Sorts the task dict at the current level of the hierarchy based on alphabetized task name.
|
| 245 |
-
Required so that we end up sorting within each sub-header correctly.
|
| 246 |
-
"""
|
| 247 |
-
|
| 248 |
-
return dict(
|
| 249 |
-
sorted(
|
| 250 |
-
task_dict.items(),
|
| 251 |
-
key=lambda item: item[0].group_name
|
| 252 |
-
if isinstance(item[0], ConfigurableGroup)
|
| 253 |
-
else item[0],
|
| 254 |
-
)
|
| 255 |
-
)
|
| 256 |
-
|
| 257 |
-
task_agg = collections.defaultdict(dict)
|
| 258 |
-
group_agg = collections.defaultdict(dict)
|
| 259 |
-
task_dict = _sort_task_dict(task_dict)
|
| 260 |
-
for task_or_group_name, task_or_group_obj in task_dict.items():
|
| 261 |
-
tab_string = " " * task_depth + "- " if task_depth > 0 else ""
|
| 262 |
-
if isinstance(task_or_group_name, ConfigurableGroup):
|
| 263 |
-
# string_name = task_or_group_name.group_name
|
| 264 |
-
name = task_or_group_name.group_name
|
| 265 |
-
from_configurable_group = True
|
| 266 |
-
task_or_group_obj = _sort_task_dict(task_or_group_obj)
|
| 267 |
-
elif isinstance(task_or_group_name, str):
|
| 268 |
-
name = task_or_group_name
|
| 269 |
-
if isinstance(task_or_group_obj, Task):
|
| 270 |
-
# string_name = task_or_group_obj.task_name
|
| 271 |
-
name = task_or_group_obj.task_name
|
| 272 |
-
from_configurable_group = False
|
| 273 |
-
|
| 274 |
-
task_agg[name] = results[name].copy()
|
| 275 |
-
if from_configurable_group:
|
| 276 |
-
if task_or_group_name.group_alias is not None:
|
| 277 |
-
alias = task_or_group_name.group_alias
|
| 278 |
-
else:
|
| 279 |
-
alias = task_or_group_name.group
|
| 280 |
-
else:
|
| 281 |
-
if "alias" in task_agg[name]:
|
| 282 |
-
alias = task_agg[name]["alias"]
|
| 283 |
-
else:
|
| 284 |
-
alias = name
|
| 285 |
-
|
| 286 |
-
task_agg[name]["alias"] = tab_string + alias
|
| 287 |
-
if "samples" in task_agg[name]:
|
| 288 |
-
task_agg[name].pop("samples")
|
| 289 |
-
|
| 290 |
-
if from_configurable_group and (" " not in results[name]):
|
| 291 |
-
group_tab_string = " " * group_depth + "- " if group_depth > 0 else ""
|
| 292 |
-
group_agg[name] = results[name].copy()
|
| 293 |
-
group_agg[name]["alias"] = group_tab_string + alias
|
| 294 |
-
if "samples" in group_agg[name]:
|
| 295 |
-
group_agg[name].pop("samples")
|
| 296 |
-
|
| 297 |
-
if isinstance(task_or_group_obj, dict):
|
| 298 |
-
task_depth += 1
|
| 299 |
-
group_depth += 1
|
| 300 |
-
_task_agg, _group_agg = prepare_print_tasks(
|
| 301 |
-
task_or_group_obj, results, task_depth, group_depth
|
| 302 |
-
)
|
| 303 |
-
task_agg = {
|
| 304 |
-
**task_agg,
|
| 305 |
-
**_task_agg,
|
| 306 |
-
}
|
| 307 |
-
group_agg = {**group_agg, **_group_agg}
|
| 308 |
-
task_depth -= 1
|
| 309 |
-
group_depth -= 1
|
| 310 |
-
return task_agg, group_agg
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
def consolidate_results(
|
| 314 |
-
eval_tasks: List[TaskOutput],
|
| 315 |
-
) -> Tuple[dict, dict, dict, dict, dict, dict]:
|
| 316 |
-
"""
|
| 317 |
-
@param eval_tasks: list(TaskOutput).
|
| 318 |
-
@return: A tuple containing the consolidated results, samples, configs, versions, and num_fewshot.
|
| 319 |
-
|
| 320 |
-
Consolidates the results of multiple evaluation tasks into a single structure.
|
| 321 |
-
|
| 322 |
-
The method iterates over each evaluation instance and extracts relevant information to create the consolidated
|
| 323 |
-
results structure. The consolidated results structure has the following properties:
|
| 324 |
-
|
| 325 |
-
- results: A defaultdict with task names as keys and dictionaries as values. Each dictionary contains
|
| 326 |
-
metric/filter pairs as keys and corresponding metric values as values. The "alias" key is used to store task
|
| 327 |
-
aliases specified in the task configuration.
|
| 328 |
-
- samples: A defaultdict with task names as keys and lists of log samples as values.
|
| 329 |
-
- configs: A defaultdict with task names as keys and task configurations as values.
|
| 330 |
-
- versions: A defaultdict with task names as keys and task versions as values.
|
| 331 |
-
- num_fewshot: A defaultdict with task names as keys and number of few-shot samples as values.
|
| 332 |
-
- higher_is_better: A defaultdict with task names as keys and indicators of whether higher values are better
|
| 333 |
-
for each metric as values.
|
| 334 |
-
|
| 335 |
-
The method then returns the consolidated results, samples, configs, versions, and num_fewshot as a tuple.
|
| 336 |
-
"""
|
| 337 |
-
# stores the final result for each task, for each metric/filter pair.
|
| 338 |
-
results = collections.defaultdict(dict)
|
| 339 |
-
# logs info about each document evaluated.
|
| 340 |
-
samples = collections.defaultdict(list)
|
| 341 |
-
# store num-fewshot value per task
|
| 342 |
-
num_fewshot = collections.defaultdict(int)
|
| 343 |
-
# Tracks the YAML configs of all chosen task
|
| 344 |
-
configs = collections.defaultdict(dict)
|
| 345 |
-
# Tracks each task's version.
|
| 346 |
-
versions = collections.defaultdict(dict)
|
| 347 |
-
# Track `higher_is_better` for each metric
|
| 348 |
-
higher_is_better = collections.defaultdict(dict)
|
| 349 |
-
|
| 350 |
-
for task_output in eval_tasks:
|
| 351 |
-
if "task_alias" in (task_config := task_output.task_config):
|
| 352 |
-
results[task_output.task_name]["alias"] = task_config["task_alias"]
|
| 353 |
-
else:
|
| 354 |
-
results[task_output.task_name]["alias"] = task_output.task_name
|
| 355 |
-
if group_alias := task_output.group_alias:
|
| 356 |
-
if group_alias not in results and (group_name := task_output.group_name):
|
| 357 |
-
results[group_name]["alias"] = group_alias
|
| 358 |
-
num_fewshot[task_output.task_name] = task_output.n_shot
|
| 359 |
-
configs[task_output.task_name] = task_output.task_config
|
| 360 |
-
versions[task_output.task_name] = task_output.version
|
| 361 |
-
samples[task_output.task_name] = task_output.logged_samples
|
| 362 |
-
higher_is_better[task_output.task_name] = task_output.task.higher_is_better()
|
| 363 |
-
for (metric, filter_key), items in task_output.sample_metrics.items():
|
| 364 |
-
metric_key = f"{metric},{filter_key}"
|
| 365 |
-
results[task_output.task_name][metric_key] = task_output.agg_metrics[
|
| 366 |
-
metric_key
|
| 367 |
-
]
|
| 368 |
-
results[task_output.task_name]["samples"] = task_output.sample_len
|
| 369 |
-
results[task_output.task_name][f"{metric}_stderr,{filter_key}"] = (
|
| 370 |
-
task_output.agg_metrics[f"{metric}_stderr,{filter_key}"]
|
| 371 |
-
)
|
| 372 |
-
return results, samples, configs, versions, num_fewshot, higher_is_better
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
def consolidate_group_results(
|
| 376 |
-
results,
|
| 377 |
-
versions,
|
| 378 |
-
task_dict,
|
| 379 |
-
task_root=None,
|
| 380 |
-
show_group_table=False,
|
| 381 |
-
task_aggregation_list=None,
|
| 382 |
-
) -> Tuple[dict, dict, bool, Union[None,]]:
|
| 383 |
-
"""
|
| 384 |
-
(Recursively) calculates groups' aggregated metrics and updates the results and versions dictionaries with this info.
|
| 385 |
-
|
| 386 |
-
@return: a tuple [results, versions, show_group_table, task_aggregation_list] with formats described below:
|
| 387 |
-
|
| 388 |
-
- results: A defaultdict with task names (and, after this function is called, group names of
|
| 389 |
-
groups that perform aggregation) as keys, and dictionaries with "alias" and metric,filter_name pairs as keys.
|
| 390 |
-
- versions: A defaultdict with task names (and, after this function is called, group names of
|
| 391 |
-
groups that perform aggregation) as keys, and float values representing the task or group's version if a version is specified. (defaulting to None).
|
| 392 |
-
- show_group_table: a boolean which is true if there exists a group that requires printing of its aggregated scores in a group table.
|
| 393 |
-
- task_aggregation_list: a defaultdict listing the subtasks to average over to produce a given group's end metric.
|
| 394 |
-
|
| 395 |
-
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
|
| 396 |
-
In the top-level invocation of this function, task_aggregation_list is ignored.
|
| 397 |
-
"""
|
| 398 |
-
if task_root is None:
|
| 399 |
-
task_root = {}
|
| 400 |
-
|
| 401 |
-
if task_aggregation_list is None:
|
| 402 |
-
task_aggregation_list = {}
|
| 403 |
-
|
| 404 |
-
for group_or_task, group_or_task_info in task_dict.items():
|
| 405 |
-
# Convert to string
|
| 406 |
-
if isinstance(group_or_task, ConfigurableGroup):
|
| 407 |
-
group_config = group_or_task.config
|
| 408 |
-
group_or_task = group_or_task.group_name
|
| 409 |
-
else:
|
| 410 |
-
group_config = None
|
| 411 |
-
|
| 412 |
-
if isinstance(group_or_task_info, Task):
|
| 413 |
-
if task_root:
|
| 414 |
-
task_aggregation_list.setdefault(task_root, []).append(
|
| 415 |
-
group_or_task_info.task_name
|
| 416 |
-
)
|
| 417 |
-
else:
|
| 418 |
-
(
|
| 419 |
-
results,
|
| 420 |
-
versions,
|
| 421 |
-
show_group_table,
|
| 422 |
-
_task_aggregation_list,
|
| 423 |
-
) = consolidate_group_results(
|
| 424 |
-
results,
|
| 425 |
-
versions,
|
| 426 |
-
group_or_task_info,
|
| 427 |
-
group_or_task,
|
| 428 |
-
show_group_table,
|
| 429 |
-
task_aggregation_list,
|
| 430 |
-
)
|
| 431 |
-
if task_root:
|
| 432 |
-
task_aggregation_list.setdefault(task_root, []).extend(
|
| 433 |
-
task_aggregation_list.get(group_or_task, [])
|
| 434 |
-
)
|
| 435 |
-
|
| 436 |
-
if (group_config is None) or (
|
| 437 |
-
group_config["aggregate_metric_list"] is None
|
| 438 |
-
):
|
| 439 |
-
results[group_or_task][" "] = " "
|
| 440 |
-
continue
|
| 441 |
-
|
| 442 |
-
if "aggregate_metric_list" in group_config:
|
| 443 |
-
agg_metric_list = group_config["aggregate_metric_list"]
|
| 444 |
-
|
| 445 |
-
show_group_table = show_group_table | bool(
|
| 446 |
-
group_config["aggregate_metric_list"]
|
| 447 |
-
)
|
| 448 |
-
|
| 449 |
-
task_list = _task_aggregation_list[group_or_task]
|
| 450 |
-
|
| 451 |
-
metric_list = list(
|
| 452 |
-
{
|
| 453 |
-
key
|
| 454 |
-
for task in task_list
|
| 455 |
-
for key in results[task].keys()
|
| 456 |
-
if "_stderr" not in key and key not in ["task", "alias", "samples"]
|
| 457 |
-
}
|
| 458 |
-
)
|
| 459 |
-
for metric in metric_list:
|
| 460 |
-
stderr = "_stderr,".join(metric.split(","))
|
| 461 |
-
|
| 462 |
-
# gather metrics, sizes, and stderrs from subtasks
|
| 463 |
-
metrics = [
|
| 464 |
-
results[task][metric]
|
| 465 |
-
for task in task_list
|
| 466 |
-
if metric in results[task]
|
| 467 |
-
] # TODO: copy?
|
| 468 |
-
stderrs = [
|
| 469 |
-
results[task][stderr]
|
| 470 |
-
for task in task_list
|
| 471 |
-
if stderr in results[task]
|
| 472 |
-
]
|
| 473 |
-
sizes = [
|
| 474 |
-
results[task]["samples"]
|
| 475 |
-
for task in task_list
|
| 476 |
-
if metric in results[task]
|
| 477 |
-
]
|
| 478 |
-
|
| 479 |
-
for metric_config in agg_metric_list:
|
| 480 |
-
for filter_name in metric_config["filter_list"]:
|
| 481 |
-
if metric != ",".join([metric_config["metric"], filter_name]):
|
| 482 |
-
continue
|
| 483 |
-
|
| 484 |
-
# compute group's pooled metric and stderr
|
| 485 |
-
if metric_config["aggregation"] == "mean":
|
| 486 |
-
aggregate_fn = aggregate_subtask_metrics
|
| 487 |
-
elif callable(metric_config["aggregation"]):
|
| 488 |
-
aggregate_fn = metric_config["aggregation"]
|
| 489 |
-
else:
|
| 490 |
-
raise ValueError(
|
| 491 |
-
f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{metric_config['aggregation']}' for group '{group_or_task}'"
|
| 492 |
-
)
|
| 493 |
-
|
| 494 |
-
results[group_or_task][metric] = aggregate_fn(
|
| 495 |
-
metrics,
|
| 496 |
-
sizes,
|
| 497 |
-
metric_config["weight_by_size"],
|
| 498 |
-
)
|
| 499 |
-
# TODO: calculate groups' metrics using arbitrary agg fns
|
| 500 |
-
if "N/A" in stderrs:
|
| 501 |
-
results[group_or_task][stderr] = "N/A"
|
| 502 |
-
else:
|
| 503 |
-
# NOTE: this assumes we are using the mean to aggregate. There are warnings about this elsewhere
|
| 504 |
-
results[group_or_task][stderr] = pooled_sample_stderr(
|
| 505 |
-
stderrs, sizes
|
| 506 |
-
)
|
| 507 |
-
|
| 508 |
-
results[group_or_task]["samples"] = sum(sizes)
|
| 509 |
-
group_metadata = group_config.get("metadata", None)
|
| 510 |
-
if group_metadata is not None:
|
| 511 |
-
versions[group_or_task] = group_metadata.get("version", None)
|
| 512 |
-
# print(results)
|
| 513 |
-
return results, versions, show_group_table, task_aggregation_list
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
@positional_deprecated
|
| 517 |
-
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
|
| 518 |
-
"""
|
| 519 |
-
Search upward in the directory tree to a maximum of three layers
|
| 520 |
-
to find and return the package root (containing the 'tests' folder)
|
| 521 |
-
"""
|
| 522 |
-
cur_path = start_path.resolve()
|
| 523 |
-
max_layers = 3
|
| 524 |
-
for _ in range(max_layers):
|
| 525 |
-
if (cur_path / "tests" / "test_version_stable.py").exists():
|
| 526 |
-
return cur_path
|
| 527 |
-
else:
|
| 528 |
-
cur_path = cur_path.parent.resolve()
|
| 529 |
-
raise FileNotFoundError(
|
| 530 |
-
f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
|
| 531 |
-
)
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
@positional_deprecated
|
| 535 |
-
def run_task_tests(task_list: List[str]):
|
| 536 |
-
"""
|
| 537 |
-
Find the package root and run the tests for the given tasks
|
| 538 |
-
"""
|
| 539 |
-
import pytest
|
| 540 |
-
|
| 541 |
-
package_root = find_test_root(start_path=pathlib.Path(__file__))
|
| 542 |
-
task_string = " or ".join(task_list)
|
| 543 |
-
args = [
|
| 544 |
-
f"{package_root}/tests/test_version_stable.py",
|
| 545 |
-
f"--rootdir={package_root}",
|
| 546 |
-
"-k",
|
| 547 |
-
f"{task_string}",
|
| 548 |
-
]
|
| 549 |
-
sys.path.append(str(package_root))
|
| 550 |
-
pytest_return_val = pytest.main(args)
|
| 551 |
-
if pytest_return_val:
|
| 552 |
-
raise ValueError(
|
| 553 |
-
f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
|
| 554 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/filters/__init__.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
from functools import partial
|
| 2 |
-
from typing import List
|
| 3 |
-
|
| 4 |
-
from lm_eval.api.filter import FilterEnsemble
|
| 5 |
-
from lm_eval.api.registry import get_filter
|
| 6 |
-
|
| 7 |
-
from . import custom, extraction, selection, transformation
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def build_filter_ensemble(
|
| 11 |
-
filter_name: str, components: List[List[str]]
|
| 12 |
-
) -> FilterEnsemble:
|
| 13 |
-
"""
|
| 14 |
-
Create a filtering pipeline.
|
| 15 |
-
"""
|
| 16 |
-
filters = []
|
| 17 |
-
for function, kwargs in components:
|
| 18 |
-
if kwargs is None:
|
| 19 |
-
kwargs = {}
|
| 20 |
-
# create a filter given its name in the registry
|
| 21 |
-
f = partial(get_filter(function), **kwargs)
|
| 22 |
-
# add the filter as a pipeline step
|
| 23 |
-
filters.append(f)
|
| 24 |
-
|
| 25 |
-
return FilterEnsemble(name=filter_name, filters=filters)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/filters/custom.py
DELETED
|
@@ -1,17 +0,0 @@
|
|
| 1 |
-
from lm_eval.api.filter import Filter
|
| 2 |
-
from lm_eval.api.registry import register_filter
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
@register_filter("custom")
|
| 6 |
-
class CustomFilter(Filter):
|
| 7 |
-
"""
|
| 8 |
-
Custom filter that applies a custom, user-defined function to the model responses.
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
def __init__(self, **kwargs) -> None:
|
| 12 |
-
self.filter_fn = kwargs.pop("filter_fn")
|
| 13 |
-
|
| 14 |
-
super().__init__(**kwargs)
|
| 15 |
-
|
| 16 |
-
def apply(self, resps, docs):
|
| 17 |
-
return self.filter_fn(resps, docs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/filters/decontamination.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
from lm_eval.api.filter import Filter
|
| 2 |
-
from lm_eval.api.registry import register_filter
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
@register_filter("decontaminate")
|
| 6 |
-
class DecontaminationFilter(Filter):
|
| 7 |
-
"""
|
| 8 |
-
A filter which evaluates
|
| 9 |
-
"""
|
| 10 |
-
|
| 11 |
-
name = "track_decontamination"
|
| 12 |
-
|
| 13 |
-
def __init__(self, path) -> None:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
|
| 17 |
-
should further cache result on a given (task_name, doc_id)
|
| 18 |
-
"""
|
| 19 |
-
self._decontam_results = None
|
| 20 |
-
|
| 21 |
-
def apply(self, resps, docs) -> None:
|
| 22 |
-
"""
|
| 23 |
-
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
|
| 24 |
-
"""
|
| 25 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/filters/extraction.py
DELETED
|
@@ -1,233 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
import sys
|
| 3 |
-
import unicodedata
|
| 4 |
-
|
| 5 |
-
from lm_eval.api.filter import Filter
|
| 6 |
-
from lm_eval.api.registry import register_filter
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
@register_filter("regex")
|
| 10 |
-
class RegexFilter(Filter):
|
| 11 |
-
"""A filter that extracts values from text using regex pattern matching.
|
| 12 |
-
|
| 13 |
-
This filter applies a regex pattern to each model response and extracts matched values.
|
| 14 |
-
If no match is found, returns a fallback value. Useful for extracting structured data
|
| 15 |
-
(like numbers) from unstructured model outputs.
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
def __init__(
|
| 19 |
-
self,
|
| 20 |
-
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
|
| 21 |
-
group_select: int = 0,
|
| 22 |
-
fallback: str = "[invalid]",
|
| 23 |
-
) -> None:
|
| 24 |
-
"""
|
| 25 |
-
pass a string `regex` to run `re.compile(r"regex")` on.
|
| 26 |
-
`fallback` defines the output returned if no matches for the regex are located.
|
| 27 |
-
"""
|
| 28 |
-
self.regex_pattern = regex_pattern
|
| 29 |
-
self.regex = re.compile(regex_pattern)
|
| 30 |
-
self.group_select = group_select
|
| 31 |
-
self.fallback = fallback
|
| 32 |
-
|
| 33 |
-
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 34 |
-
# here, we assume we have a list, in which each element is
|
| 35 |
-
# a list of model responses for some particular input/target pair.
|
| 36 |
-
# so we process each of these (same input/target response sets)
|
| 37 |
-
# independently (and keep them a list.)
|
| 38 |
-
def filter_set(inst):
|
| 39 |
-
filtered = []
|
| 40 |
-
for resp in inst:
|
| 41 |
-
match = self.regex.findall(resp)
|
| 42 |
-
if match:
|
| 43 |
-
match = match[self.group_select]
|
| 44 |
-
if isinstance(match, tuple):
|
| 45 |
-
match = [m for m in match if m]
|
| 46 |
-
if match:
|
| 47 |
-
match = match[0]
|
| 48 |
-
else:
|
| 49 |
-
match = self.fallback
|
| 50 |
-
match = match.strip()
|
| 51 |
-
else:
|
| 52 |
-
match = self.fallback
|
| 53 |
-
filtered.append(match)
|
| 54 |
-
return filtered
|
| 55 |
-
|
| 56 |
-
filtered_resps = list(map(lambda x: filter_set(x), resps))
|
| 57 |
-
return filtered_resps
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
@register_filter("regex_pos")
|
| 61 |
-
class POSFilter(Filter):
|
| 62 |
-
""" """
|
| 63 |
-
|
| 64 |
-
def __init__(
|
| 65 |
-
self,
|
| 66 |
-
regex_pattern: str = r"\['(.*?)'\]",
|
| 67 |
-
group_select=0,
|
| 68 |
-
fallback=None,
|
| 69 |
-
) -> None:
|
| 70 |
-
"""
|
| 71 |
-
pass a string `regex` to run `re.compile(r"regex")` on.
|
| 72 |
-
`fallback` defines the output returned if no matches for the regex are located.
|
| 73 |
-
"""
|
| 74 |
-
if fallback is None:
|
| 75 |
-
fallback = ["invalid"]
|
| 76 |
-
self.regex_pattern = regex_pattern
|
| 77 |
-
self.regex = re.compile(regex_pattern)
|
| 78 |
-
self.group_select = group_select
|
| 79 |
-
self.fallback = fallback
|
| 80 |
-
|
| 81 |
-
def apply(self, resps, docs):
|
| 82 |
-
def extract_tagged_tokens(text):
|
| 83 |
-
# Extract tagged tokens list from text input using regex
|
| 84 |
-
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
|
| 85 |
-
return [(token, pos) for token, pos in tokens]
|
| 86 |
-
|
| 87 |
-
def extract_pos_tags(result):
|
| 88 |
-
pos_tags = []
|
| 89 |
-
if isinstance(result, str):
|
| 90 |
-
result = extract_tagged_tokens(result)
|
| 91 |
-
pos_tags.extend(pos for _, pos in result)
|
| 92 |
-
return pos_tags if pos_tags else self.fallback
|
| 93 |
-
|
| 94 |
-
def filter_set(inst):
|
| 95 |
-
filtered = []
|
| 96 |
-
for resp in inst:
|
| 97 |
-
match = extract_pos_tags(resp)
|
| 98 |
-
filtered.append(match)
|
| 99 |
-
return filtered
|
| 100 |
-
|
| 101 |
-
filtered_resps = map(lambda x: filter_set(x), resps)
|
| 102 |
-
|
| 103 |
-
return filtered_resps
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
@register_filter("remove_whitespace")
|
| 107 |
-
class WhitespaceFilter(Filter):
|
| 108 |
-
"""Filters out leading whitespace from responses."""
|
| 109 |
-
|
| 110 |
-
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 111 |
-
def filter_set(inst):
|
| 112 |
-
filtered_resp = []
|
| 113 |
-
for resp in inst:
|
| 114 |
-
resp = resp.lstrip()
|
| 115 |
-
filtered_resp.append(resp)
|
| 116 |
-
return filtered_resp
|
| 117 |
-
|
| 118 |
-
filtered_resps = [filter_set(resp) for resp in resps]
|
| 119 |
-
|
| 120 |
-
return filtered_resps
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
@register_filter("multi_choice_regex")
|
| 124 |
-
class MultiChoiceRegexFilter(RegexFilter):
|
| 125 |
-
"""
|
| 126 |
-
A filter used to extract a model's answer on multiple choice questions with
|
| 127 |
-
letter answers. assumes each document has a "choices" field
|
| 128 |
-
containing the list of answer choices and that the answer label symbols
|
| 129 |
-
are of the form (A), (B), (C), ... or A, B, C.
|
| 130 |
-
"""
|
| 131 |
-
|
| 132 |
-
def __init__(
|
| 133 |
-
self,
|
| 134 |
-
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
|
| 135 |
-
group_select=0,
|
| 136 |
-
fallback: str = "[invalid]",
|
| 137 |
-
ignore_case=False,
|
| 138 |
-
ignore_punctuation=False,
|
| 139 |
-
regexes_to_ignore=None,
|
| 140 |
-
) -> None:
|
| 141 |
-
"""
|
| 142 |
-
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
|
| 143 |
-
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
|
| 144 |
-
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
|
| 145 |
-
group_select: Selects the (group_select)th match from the findall result.
|
| 146 |
-
ignore_case: Ignores the case during step 1 matching
|
| 147 |
-
ignore_punctuation: Remove the punctuation during step 1 matching
|
| 148 |
-
regexes_to_ignore: Remove these regexes during step 1 matching
|
| 149 |
-
"""
|
| 150 |
-
super().__init__(regex_pattern, group_select, fallback)
|
| 151 |
-
self.ignore_case = ignore_case
|
| 152 |
-
self.ignore_punctuation = ignore_punctuation
|
| 153 |
-
self.regexes_to_ignore = regexes_to_ignore
|
| 154 |
-
|
| 155 |
-
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
|
| 156 |
-
# here, we assume we have a list, in which each element is
|
| 157 |
-
# a list of model responses for some particular input/target pair.
|
| 158 |
-
# so we process each of these (same input/target response sets)
|
| 159 |
-
# independently (and keep them a list.)
|
| 160 |
-
|
| 161 |
-
def find_match(regex, resp, convert_dict={}):
|
| 162 |
-
match = regex.findall(resp)
|
| 163 |
-
if match:
|
| 164 |
-
match = match[self.group_select]
|
| 165 |
-
if isinstance(match, tuple):
|
| 166 |
-
match = [m for m in match if m][0]
|
| 167 |
-
match = match.strip()
|
| 168 |
-
if match and match in convert_dict:
|
| 169 |
-
match = convert_dict[match]
|
| 170 |
-
return match
|
| 171 |
-
|
| 172 |
-
punct_tbl = dict.fromkeys(
|
| 173 |
-
i
|
| 174 |
-
for i in range(sys.maxunicode)
|
| 175 |
-
if unicodedata.category(chr(i)).startswith("P")
|
| 176 |
-
)
|
| 177 |
-
|
| 178 |
-
def filter_ignores(st):
|
| 179 |
-
if self.regexes_to_ignore is not None:
|
| 180 |
-
for s in self.regexes_to_ignore:
|
| 181 |
-
st = re.sub(s, "", st)
|
| 182 |
-
|
| 183 |
-
if self.ignore_case:
|
| 184 |
-
st = st.lower()
|
| 185 |
-
|
| 186 |
-
if self.ignore_punctuation:
|
| 187 |
-
# https://stackoverflow.com/a/266162
|
| 188 |
-
st = st.translate(punct_tbl)
|
| 189 |
-
return st
|
| 190 |
-
|
| 191 |
-
filtered_resps = []
|
| 192 |
-
|
| 193 |
-
for r, doc in zip(resps, docs):
|
| 194 |
-
fallback_regexes = []
|
| 195 |
-
choice_to_alpha = {}
|
| 196 |
-
next_alpha = "A"
|
| 197 |
-
|
| 198 |
-
without_paren_fallback_regexes = []
|
| 199 |
-
without_paren_to_target = {}
|
| 200 |
-
|
| 201 |
-
choices = doc["choices"]
|
| 202 |
-
for c in choices:
|
| 203 |
-
m = filter_ignores(c.strip())
|
| 204 |
-
fallback_regexes.append(f"{re.escape(m)}")
|
| 205 |
-
choice_to_alpha[m] = f"({next_alpha})"
|
| 206 |
-
|
| 207 |
-
without_paren_fallback_regexes.append(next_alpha)
|
| 208 |
-
without_paren_to_target[next_alpha] = f"({next_alpha})"
|
| 209 |
-
|
| 210 |
-
next_alpha = chr(ord(next_alpha) + 1)
|
| 211 |
-
fallback_regex = re.compile("|".join(fallback_regexes))
|
| 212 |
-
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
|
| 213 |
-
without_paren_fallback_regex = re.compile(
|
| 214 |
-
rf":[\s]*({without_paren_fallback_regex})"
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
filtered = []
|
| 218 |
-
for resp in r:
|
| 219 |
-
match = find_match(self.regex, resp)
|
| 220 |
-
if not match:
|
| 221 |
-
match = find_match(
|
| 222 |
-
fallback_regex, filter_ignores(resp), choice_to_alpha
|
| 223 |
-
)
|
| 224 |
-
if not match:
|
| 225 |
-
match = find_match(
|
| 226 |
-
without_paren_fallback_regex, resp, without_paren_to_target
|
| 227 |
-
)
|
| 228 |
-
if not match:
|
| 229 |
-
match = self.fallback
|
| 230 |
-
filtered.append(match)
|
| 231 |
-
filtered_resps.append(filtered)
|
| 232 |
-
|
| 233 |
-
return filtered_resps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/filters/selection.py
DELETED
|
@@ -1,61 +0,0 @@
|
|
| 1 |
-
from collections import Counter
|
| 2 |
-
|
| 3 |
-
from lm_eval.api.filter import Filter
|
| 4 |
-
from lm_eval.api.registry import register_filter
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
|
| 8 |
-
# that takes an input and returns a scalar and then should select the max reward,
|
| 9 |
-
# or should implement different filters for different ways of handling a reward model's inference.
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
@register_filter("take_first")
|
| 13 |
-
class TakeFirstFilter(Filter):
|
| 14 |
-
def __init__(self) -> None:
|
| 15 |
-
"""
|
| 16 |
-
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 17 |
-
"""
|
| 18 |
-
|
| 19 |
-
def apply(self, resps, docs):
|
| 20 |
-
"""
|
| 21 |
-
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
|
| 22 |
-
"""
|
| 23 |
-
return map(lambda r: r[0], resps)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
@register_filter("take_first_k")
|
| 27 |
-
class TakeKFilter(Filter):
|
| 28 |
-
def __init__(self, **kwargs) -> None:
|
| 29 |
-
self.k = kwargs.pop("k")
|
| 30 |
-
|
| 31 |
-
super().__init__(**kwargs)
|
| 32 |
-
|
| 33 |
-
def apply(self, resps, docs):
|
| 34 |
-
# need resp to be subscriptable to check below
|
| 35 |
-
resps = list(resps)
|
| 36 |
-
# check we have at least k responses per doc, else we can't take the first k
|
| 37 |
-
assert len(resps[0]) >= self.k, (
|
| 38 |
-
f"Need at least {self.k} responses per doc to take first {self.k}, but got {len(resps[0])} only! Please increase TaskConfig.repeats ."
|
| 39 |
-
)
|
| 40 |
-
return map(lambda r: r[: self.k], resps)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
@register_filter("majority_vote")
|
| 44 |
-
class MajorityVoteFilter(Filter):
|
| 45 |
-
def __init__(self) -> None:
|
| 46 |
-
"""
|
| 47 |
-
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
def apply(self, resps, docs):
|
| 51 |
-
"""
|
| 52 |
-
Each entry of `resps` is a list of model responses.
|
| 53 |
-
We select the response that occurs most frequently in each entry of `resps`.
|
| 54 |
-
"""
|
| 55 |
-
|
| 56 |
-
def select_majority(resp):
|
| 57 |
-
counts = Counter(resp)
|
| 58 |
-
vote = counts.most_common(1)[0][0]
|
| 59 |
-
return vote
|
| 60 |
-
|
| 61 |
-
return map(lambda r: [select_majority(r)], resps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/filters/transformation.py
DELETED
|
@@ -1,122 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
|
| 3 |
-
from lm_eval.api.filter import Filter
|
| 4 |
-
from lm_eval.api.registry import register_filter
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
@register_filter("lowercase")
|
| 8 |
-
class LowercaseFilter(Filter):
|
| 9 |
-
def __init__(self) -> None:
|
| 10 |
-
pass
|
| 11 |
-
|
| 12 |
-
def apply(self, resps, docs):
|
| 13 |
-
def filter_set(inst):
|
| 14 |
-
return [resp.lower() for resp in inst]
|
| 15 |
-
|
| 16 |
-
return [filter_set(resp) for resp in resps]
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
@register_filter("uppercase")
|
| 20 |
-
class UppercaseFilter(Filter):
|
| 21 |
-
def __init__(self) -> None:
|
| 22 |
-
pass
|
| 23 |
-
|
| 24 |
-
def apply(self, resps, docs):
|
| 25 |
-
def filter_set(inst):
|
| 26 |
-
return [resp.upper() for resp in inst]
|
| 27 |
-
|
| 28 |
-
return [filter_set(resp) for resp in resps]
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
@register_filter("map")
|
| 32 |
-
class MapFilter(Filter):
|
| 33 |
-
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
|
| 34 |
-
"""
|
| 35 |
-
Initializes the MapFilter with a given mapping dictionary and default value.
|
| 36 |
-
|
| 37 |
-
Args:
|
| 38 |
-
- mapping_dict (dict): A dictionary containing the key-value mappings.
|
| 39 |
-
Default is an empty dictionary.
|
| 40 |
-
- default_value (Any): The value to be returned when a key is not found in the mapping_dict.
|
| 41 |
-
Default is None.
|
| 42 |
-
|
| 43 |
-
Example:
|
| 44 |
-
mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
|
| 45 |
-
"""
|
| 46 |
-
if mapping_dict is None:
|
| 47 |
-
mapping_dict = {}
|
| 48 |
-
assert isinstance(mapping_dict, dict), (
|
| 49 |
-
"Provided mapping_dict is not a dictionary"
|
| 50 |
-
)
|
| 51 |
-
self.mapping_dict = mapping_dict
|
| 52 |
-
self.default_value = default_value
|
| 53 |
-
|
| 54 |
-
def apply(self, resps, docs):
|
| 55 |
-
def filter_set(inst):
|
| 56 |
-
return [self.mapping_dict.get(resp, self.default_value) for resp in inst]
|
| 57 |
-
|
| 58 |
-
return [filter_set(resp) for resp in resps]
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
@register_filter("format_span")
|
| 62 |
-
class SPANFilter(Filter):
|
| 63 |
-
def __init__(self) -> None:
|
| 64 |
-
pass
|
| 65 |
-
|
| 66 |
-
def apply(self, resps, docs):
|
| 67 |
-
def format_ner_text(text):
|
| 68 |
-
label_dict = {
|
| 69 |
-
"person": "PER",
|
| 70 |
-
"location": "LOC",
|
| 71 |
-
"organization": "ORG",
|
| 72 |
-
"counties": "LOC",
|
| 73 |
-
"places": "LOC",
|
| 74 |
-
"people": "PER",
|
| 75 |
-
"persons": "PER",
|
| 76 |
-
"company": "ORG",
|
| 77 |
-
"country": "LOC",
|
| 78 |
-
"continent": "LOC",
|
| 79 |
-
"time": "DATE",
|
| 80 |
-
"date": "DATE",
|
| 81 |
-
"per": "PER",
|
| 82 |
-
"loc": "LOC",
|
| 83 |
-
"org": "ORG",
|
| 84 |
-
}
|
| 85 |
-
text = text.lower()
|
| 86 |
-
for key, value in label_dict.items():
|
| 87 |
-
text = text.replace(key, value)
|
| 88 |
-
|
| 89 |
-
text = "$".join(i for i in text.split("$$"))
|
| 90 |
-
return text.rstrip("$$")
|
| 91 |
-
|
| 92 |
-
def format_named_entities(text):
|
| 93 |
-
"""
|
| 94 |
-
Extract named entities from text and format them as 'label: value $$ label: value'.
|
| 95 |
-
Handles grouped entities (e.g., LOC: kenya, uganda) and excludes 'none' values.
|
| 96 |
-
"""
|
| 97 |
-
# Regular expression to match label: entities pattern
|
| 98 |
-
pattern = r"\b(PER|LOC|ORG|DATE):\s*([^$]+)"
|
| 99 |
-
# Normalize newline characters
|
| 100 |
-
text = text.replace("\n", "$").strip()
|
| 101 |
-
matches = re.findall(pattern, text)
|
| 102 |
-
|
| 103 |
-
formatted_entities = []
|
| 104 |
-
|
| 105 |
-
for label, values in matches:
|
| 106 |
-
# Split multiple entities separated by commas and strip whitespace
|
| 107 |
-
entities = [value.strip() for value in values.split(",")]
|
| 108 |
-
|
| 109 |
-
# Exclude 'none' entities
|
| 110 |
-
for entity in entities:
|
| 111 |
-
if entity.lower() != "none":
|
| 112 |
-
formatted_entities.append(f"{label.lower()}: {entity}")
|
| 113 |
-
|
| 114 |
-
# Join entities with the desired separator
|
| 115 |
-
return " $ ".join(formatted_entities)
|
| 116 |
-
|
| 117 |
-
def filter_set(inst):
|
| 118 |
-
return [
|
| 119 |
-
format_named_entities(format_ner_text(resp.lower())) for resp in inst
|
| 120 |
-
]
|
| 121 |
-
|
| 122 |
-
return [filter_set(resp) for resp in resps]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/loggers/__init__.py
DELETED
|
@@ -1,2 +0,0 @@
|
|
| 1 |
-
from .evaluation_tracker import EvaluationTracker
|
| 2 |
-
from .wandb_logger import WandbLogger
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/loggers/evaluation_tracker.py
DELETED
|
@@ -1,537 +0,0 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import logging
|
| 3 |
-
import os
|
| 4 |
-
import re
|
| 5 |
-
import time
|
| 6 |
-
from collections import defaultdict
|
| 7 |
-
from dataclasses import asdict, dataclass
|
| 8 |
-
from datetime import datetime
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
|
| 11 |
-
from datasets import load_dataset
|
| 12 |
-
from datasets.utils.metadata import MetadataConfigs
|
| 13 |
-
from huggingface_hub import (
|
| 14 |
-
DatasetCard,
|
| 15 |
-
DatasetCardData,
|
| 16 |
-
HfApi,
|
| 17 |
-
hf_hub_url,
|
| 18 |
-
)
|
| 19 |
-
from huggingface_hub.utils import build_hf_headers, get_session, hf_raise_for_status
|
| 20 |
-
|
| 21 |
-
from lm_eval.utils import (
|
| 22 |
-
get_file_datetime,
|
| 23 |
-
get_file_task_name,
|
| 24 |
-
get_results_filenames,
|
| 25 |
-
get_sample_results_filenames,
|
| 26 |
-
handle_non_serializable,
|
| 27 |
-
hash_string,
|
| 28 |
-
sanitize_list,
|
| 29 |
-
sanitize_model_name,
|
| 30 |
-
sanitize_task_name,
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
eval_logger = logging.getLogger(__name__)
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
@dataclass(init=False)
|
| 38 |
-
class GeneralConfigTracker:
|
| 39 |
-
"""
|
| 40 |
-
Tracker for the evaluation parameters.
|
| 41 |
-
|
| 42 |
-
Attributes:
|
| 43 |
-
model_source (str): Source of the model (e.g. Hugging Face, GGUF, etc.)
|
| 44 |
-
model_name (str): Name of the model.
|
| 45 |
-
model_name_sanitized (str): Sanitized model name for directory creation.
|
| 46 |
-
start_time (float): Start time of the experiment. Logged at class init.
|
| 47 |
-
end_time (float): Start time of the experiment. Logged when calling [`GeneralConfigTracker.log_end_time`]
|
| 48 |
-
total_evaluation_time_seconds (str): Inferred total evaluation time in seconds (from the start and end times).
|
| 49 |
-
"""
|
| 50 |
-
|
| 51 |
-
model_source: str = None
|
| 52 |
-
model_name: str = None
|
| 53 |
-
model_name_sanitized: str = None
|
| 54 |
-
system_instruction: str = None
|
| 55 |
-
system_instruction_sha: str = None
|
| 56 |
-
fewshot_as_multiturn: bool = None
|
| 57 |
-
chat_template: str = None
|
| 58 |
-
chat_template_sha: str = None
|
| 59 |
-
start_time: float = None
|
| 60 |
-
end_time: float = None
|
| 61 |
-
total_evaluation_time_seconds: str = None
|
| 62 |
-
|
| 63 |
-
def __init__(self) -> None:
|
| 64 |
-
"""Starts the evaluation timer."""
|
| 65 |
-
self.start_time = time.perf_counter()
|
| 66 |
-
|
| 67 |
-
@staticmethod
|
| 68 |
-
def _get_model_name(model_args: str) -> str:
|
| 69 |
-
"""Extracts the model name from the model arguments."""
|
| 70 |
-
|
| 71 |
-
def extract_model_name(model_args: str, key: str) -> str:
|
| 72 |
-
"""Extracts the model name from the model arguments using a key."""
|
| 73 |
-
args_after_key = model_args.split(key)[1]
|
| 74 |
-
return args_after_key.split(",")[0]
|
| 75 |
-
|
| 76 |
-
# order does matter, e.g. peft and delta are provided together with pretrained
|
| 77 |
-
prefixes = ["peft=", "delta=", "pretrained=", "model=", "path=", "engine="]
|
| 78 |
-
for prefix in prefixes:
|
| 79 |
-
if prefix in model_args:
|
| 80 |
-
return extract_model_name(model_args, prefix)
|
| 81 |
-
return ""
|
| 82 |
-
|
| 83 |
-
def log_experiment_args(
|
| 84 |
-
self,
|
| 85 |
-
model_source: str,
|
| 86 |
-
model_args: str,
|
| 87 |
-
system_instruction: str,
|
| 88 |
-
chat_template: str,
|
| 89 |
-
fewshot_as_multiturn: bool,
|
| 90 |
-
) -> None:
|
| 91 |
-
"""Logs model parameters and job ID."""
|
| 92 |
-
self.model_source = model_source
|
| 93 |
-
self.model_name = GeneralConfigTracker._get_model_name(model_args)
|
| 94 |
-
self.model_name_sanitized = sanitize_model_name(self.model_name)
|
| 95 |
-
self.system_instruction = system_instruction
|
| 96 |
-
self.system_instruction_sha = (
|
| 97 |
-
hash_string(system_instruction) if system_instruction else None
|
| 98 |
-
)
|
| 99 |
-
self.chat_template = chat_template
|
| 100 |
-
self.chat_template_sha = hash_string(chat_template) if chat_template else None
|
| 101 |
-
self.fewshot_as_multiturn = fewshot_as_multiturn
|
| 102 |
-
|
| 103 |
-
def log_end_time(self) -> None:
|
| 104 |
-
"""Logs the end time of the evaluation and calculates the total evaluation time."""
|
| 105 |
-
self.end_time = time.perf_counter()
|
| 106 |
-
self.total_evaluation_time_seconds = str(self.end_time - self.start_time)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
class EvaluationTracker:
|
| 110 |
-
"""
|
| 111 |
-
Keeps track and saves relevant information of the evaluation process.
|
| 112 |
-
Compiles the data from trackers and writes it to files, which can be published to the Hugging Face hub if requested.
|
| 113 |
-
"""
|
| 114 |
-
|
| 115 |
-
def __init__(
|
| 116 |
-
self,
|
| 117 |
-
output_path: str = None,
|
| 118 |
-
hub_results_org: str = "",
|
| 119 |
-
hub_repo_name: str = "",
|
| 120 |
-
details_repo_name: str = "",
|
| 121 |
-
results_repo_name: str = "",
|
| 122 |
-
push_results_to_hub: bool = False,
|
| 123 |
-
push_samples_to_hub: bool = False,
|
| 124 |
-
public_repo: bool = False,
|
| 125 |
-
token: str = "",
|
| 126 |
-
leaderboard_url: str = "",
|
| 127 |
-
point_of_contact: str = "",
|
| 128 |
-
gated: bool = False,
|
| 129 |
-
) -> None:
|
| 130 |
-
"""
|
| 131 |
-
Creates all the necessary loggers for evaluation tracking.
|
| 132 |
-
|
| 133 |
-
Args:
|
| 134 |
-
output_path (str): Path to save the results. If not provided, the results won't be saved.
|
| 135 |
-
hub_results_org (str): The Hugging Face organization to push the results to. If not provided, the results will be pushed to the owner of the Hugging Face token.
|
| 136 |
-
hub_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will be pushed to `lm-eval-results`.
|
| 137 |
-
details_repo_name (str): The name of the Hugging Face repository to push the details to. If not provided, the results will be pushed to `lm-eval-results`.
|
| 138 |
-
result_repo_name (str): The name of the Hugging Face repository to push the results to. If not provided, the results will not be pushed and will be found in the details_hub_repo.
|
| 139 |
-
push_results_to_hub (bool): Whether to push the results to the Hugging Face hub.
|
| 140 |
-
push_samples_to_hub (bool): Whether to push the samples to the Hugging Face hub.
|
| 141 |
-
public_repo (bool): Whether to push the results to a public or private repository.
|
| 142 |
-
token (str): Token to use when pushing to the Hugging Face hub. This token should have write access to `hub_results_org`.
|
| 143 |
-
leaderboard_url (str): URL to the leaderboard on the Hugging Face hub on the dataset card.
|
| 144 |
-
point_of_contact (str): Contact information on the Hugging Face hub dataset card.
|
| 145 |
-
gated (bool): Whether to gate the repository.
|
| 146 |
-
"""
|
| 147 |
-
self.general_config_tracker = GeneralConfigTracker()
|
| 148 |
-
|
| 149 |
-
self.output_path = output_path
|
| 150 |
-
self.push_results_to_hub = push_results_to_hub
|
| 151 |
-
self.push_samples_to_hub = push_samples_to_hub
|
| 152 |
-
self.public_repo = public_repo
|
| 153 |
-
self.leaderboard_url = leaderboard_url
|
| 154 |
-
self.point_of_contact = point_of_contact
|
| 155 |
-
self.api = HfApi(token=token) if token else None
|
| 156 |
-
self.gated_repo = gated
|
| 157 |
-
|
| 158 |
-
if not self.api and (push_results_to_hub or push_samples_to_hub):
|
| 159 |
-
raise ValueError(
|
| 160 |
-
"Hugging Face token is not defined, but 'push_results_to_hub' or 'push_samples_to_hub' is set to True. "
|
| 161 |
-
"Please provide a valid Hugging Face token by setting the HF_TOKEN environment variable."
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
if (
|
| 165 |
-
self.api
|
| 166 |
-
and hub_results_org == ""
|
| 167 |
-
and (push_results_to_hub or push_samples_to_hub)
|
| 168 |
-
):
|
| 169 |
-
hub_results_org = self.api.whoami()["name"]
|
| 170 |
-
eval_logger.warning(
|
| 171 |
-
f"hub_results_org was not specified. Results will be pushed to '{hub_results_org}'."
|
| 172 |
-
)
|
| 173 |
-
|
| 174 |
-
if hub_repo_name == "":
|
| 175 |
-
details_repo_name = (
|
| 176 |
-
details_repo_name if details_repo_name != "" else "lm-eval-results"
|
| 177 |
-
)
|
| 178 |
-
results_repo_name = (
|
| 179 |
-
results_repo_name if results_repo_name != "" else details_repo_name
|
| 180 |
-
)
|
| 181 |
-
else:
|
| 182 |
-
details_repo_name = hub_repo_name
|
| 183 |
-
results_repo_name = hub_repo_name
|
| 184 |
-
eval_logger.warning(
|
| 185 |
-
"hub_repo_name was specified. Both details and results will be pushed to the same repository. Using hub_repo_name is no longer recommended, details_repo_name and results_repo_name should be used instead."
|
| 186 |
-
)
|
| 187 |
-
|
| 188 |
-
self.details_repo = f"{hub_results_org}/{details_repo_name}"
|
| 189 |
-
self.details_repo_private = f"{hub_results_org}/{details_repo_name}-private"
|
| 190 |
-
self.results_repo = f"{hub_results_org}/{results_repo_name}"
|
| 191 |
-
self.results_repo_private = f"{hub_results_org}/{results_repo_name}-private"
|
| 192 |
-
|
| 193 |
-
def save_results_aggregated(
|
| 194 |
-
self,
|
| 195 |
-
results: dict,
|
| 196 |
-
samples: dict,
|
| 197 |
-
) -> None:
|
| 198 |
-
"""
|
| 199 |
-
Saves the aggregated results and samples to the output path and pushes them to the Hugging Face hub if requested.
|
| 200 |
-
|
| 201 |
-
Args:
|
| 202 |
-
results (dict): The aggregated results to save.
|
| 203 |
-
samples (dict): The samples results to save.
|
| 204 |
-
"""
|
| 205 |
-
self.general_config_tracker.log_end_time()
|
| 206 |
-
|
| 207 |
-
if self.output_path:
|
| 208 |
-
try:
|
| 209 |
-
eval_logger.info("Saving results aggregated")
|
| 210 |
-
|
| 211 |
-
# calculate cumulative hash for each task - only if samples are provided
|
| 212 |
-
task_hashes = {}
|
| 213 |
-
if samples:
|
| 214 |
-
for task_name, task_samples in samples.items():
|
| 215 |
-
sample_hashes = [
|
| 216 |
-
s["doc_hash"] + s["prompt_hash"] + s["target_hash"]
|
| 217 |
-
for s in task_samples
|
| 218 |
-
]
|
| 219 |
-
task_hashes[task_name] = hash_string("".join(sample_hashes))
|
| 220 |
-
|
| 221 |
-
# update initial results dict
|
| 222 |
-
results.update({"task_hashes": task_hashes})
|
| 223 |
-
results.update(asdict(self.general_config_tracker))
|
| 224 |
-
dumped = json.dumps(
|
| 225 |
-
results,
|
| 226 |
-
indent=2,
|
| 227 |
-
default=handle_non_serializable,
|
| 228 |
-
ensure_ascii=False,
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
-
path = Path(self.output_path if self.output_path else Path.cwd())
|
| 232 |
-
self.date_id = datetime.now().isoformat().replace(":", "-")
|
| 233 |
-
if path.suffix == ".json":
|
| 234 |
-
path.parent.mkdir(parents=True, exist_ok=True)
|
| 235 |
-
file_results_aggregated = path.with_name(
|
| 236 |
-
f"{path.stem}_{self.date_id}.json"
|
| 237 |
-
)
|
| 238 |
-
else:
|
| 239 |
-
path = path.joinpath(
|
| 240 |
-
self.general_config_tracker.model_name_sanitized
|
| 241 |
-
)
|
| 242 |
-
path.mkdir(parents=True, exist_ok=True)
|
| 243 |
-
file_results_aggregated = path.joinpath(
|
| 244 |
-
f"results_{self.date_id}.json"
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
-
file_results_aggregated.open("w", encoding="utf-8").write(dumped)
|
| 248 |
-
|
| 249 |
-
if self.api and self.push_results_to_hub:
|
| 250 |
-
repo_id = (
|
| 251 |
-
self.results_repo
|
| 252 |
-
if self.public_repo
|
| 253 |
-
else self.results_repo_private
|
| 254 |
-
)
|
| 255 |
-
self.api.create_repo(
|
| 256 |
-
repo_id=repo_id,
|
| 257 |
-
repo_type="dataset",
|
| 258 |
-
private=not self.public_repo,
|
| 259 |
-
exist_ok=True,
|
| 260 |
-
)
|
| 261 |
-
self.api.upload_file(
|
| 262 |
-
repo_id=repo_id,
|
| 263 |
-
path_or_fileobj=str(file_results_aggregated),
|
| 264 |
-
path_in_repo=os.path.join(
|
| 265 |
-
self.general_config_tracker.model_name,
|
| 266 |
-
file_results_aggregated.name,
|
| 267 |
-
),
|
| 268 |
-
repo_type="dataset",
|
| 269 |
-
commit_message=f"Adding aggregated results for {self.general_config_tracker.model_name}",
|
| 270 |
-
)
|
| 271 |
-
eval_logger.info(
|
| 272 |
-
"Successfully pushed aggregated results to the Hugging Face Hub. "
|
| 273 |
-
f"You can find them at: {repo_id}"
|
| 274 |
-
)
|
| 275 |
-
|
| 276 |
-
except Exception as e:
|
| 277 |
-
eval_logger.warning("Could not save results aggregated")
|
| 278 |
-
eval_logger.info(repr(e))
|
| 279 |
-
else:
|
| 280 |
-
eval_logger.info(
|
| 281 |
-
"Output path not provided, skipping saving results aggregated"
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
def save_results_samples(
|
| 285 |
-
self,
|
| 286 |
-
task_name: str,
|
| 287 |
-
samples: dict,
|
| 288 |
-
) -> None:
|
| 289 |
-
"""
|
| 290 |
-
Saves the samples results to the output path and pushes them to the Hugging Face hub if requested.
|
| 291 |
-
|
| 292 |
-
Args:
|
| 293 |
-
task_name (str): The task name to save the samples for.
|
| 294 |
-
samples (dict): The samples results to save.
|
| 295 |
-
"""
|
| 296 |
-
if self.output_path:
|
| 297 |
-
try:
|
| 298 |
-
eval_logger.info(f"Saving per-sample results for: {task_name}")
|
| 299 |
-
|
| 300 |
-
path = Path(self.output_path if self.output_path else Path.cwd())
|
| 301 |
-
if path.suffix == ".json":
|
| 302 |
-
path = path.parent
|
| 303 |
-
else:
|
| 304 |
-
path = path.joinpath(
|
| 305 |
-
self.general_config_tracker.model_name_sanitized
|
| 306 |
-
)
|
| 307 |
-
path.mkdir(parents=True, exist_ok=True)
|
| 308 |
-
|
| 309 |
-
file_results_samples = path.joinpath(
|
| 310 |
-
f"samples_{task_name}_{self.date_id}.jsonl"
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
for sample in samples:
|
| 314 |
-
# we first need to sanitize arguments and resps
|
| 315 |
-
# otherwise we won't be able to load the dataset
|
| 316 |
-
# using the datasets library
|
| 317 |
-
arguments = {}
|
| 318 |
-
for i, arg in enumerate(sample["arguments"]):
|
| 319 |
-
arguments[f"gen_args_{i}"] = {}
|
| 320 |
-
for j, tmp in enumerate(arg):
|
| 321 |
-
arguments[f"gen_args_{i}"][f"arg_{j}"] = tmp
|
| 322 |
-
|
| 323 |
-
sample["resps"] = sanitize_list(sample["resps"])
|
| 324 |
-
sample["filtered_resps"] = sanitize_list(sample["filtered_resps"])
|
| 325 |
-
sample["arguments"] = arguments
|
| 326 |
-
sample["target"] = str(sample["target"])
|
| 327 |
-
|
| 328 |
-
sample_dump = (
|
| 329 |
-
json.dumps(
|
| 330 |
-
sample,
|
| 331 |
-
default=handle_non_serializable,
|
| 332 |
-
ensure_ascii=False,
|
| 333 |
-
)
|
| 334 |
-
+ "\n"
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
with open(file_results_samples, "a", encoding="utf-8") as f:
|
| 338 |
-
f.write(sample_dump)
|
| 339 |
-
|
| 340 |
-
if self.api and self.push_samples_to_hub:
|
| 341 |
-
repo_id = (
|
| 342 |
-
self.details_repo
|
| 343 |
-
if self.public_repo
|
| 344 |
-
else self.details_repo_private
|
| 345 |
-
)
|
| 346 |
-
self.api.create_repo(
|
| 347 |
-
repo_id=repo_id,
|
| 348 |
-
repo_type="dataset",
|
| 349 |
-
private=not self.public_repo,
|
| 350 |
-
exist_ok=True,
|
| 351 |
-
)
|
| 352 |
-
try:
|
| 353 |
-
if self.gated_repo:
|
| 354 |
-
headers = build_hf_headers()
|
| 355 |
-
r = get_session().put(
|
| 356 |
-
url=f"https://huggingface.co/api/datasets/{repo_id}/settings",
|
| 357 |
-
headers=headers,
|
| 358 |
-
json={"gated": "auto"},
|
| 359 |
-
)
|
| 360 |
-
hf_raise_for_status(r)
|
| 361 |
-
except Exception as e:
|
| 362 |
-
eval_logger.warning("Could not gate the repository")
|
| 363 |
-
eval_logger.info(repr(e))
|
| 364 |
-
self.api.upload_folder(
|
| 365 |
-
repo_id=repo_id,
|
| 366 |
-
folder_path=str(path),
|
| 367 |
-
path_in_repo=self.general_config_tracker.model_name_sanitized,
|
| 368 |
-
repo_type="dataset",
|
| 369 |
-
commit_message=f"Adding samples results for {task_name} to {self.general_config_tracker.model_name}",
|
| 370 |
-
)
|
| 371 |
-
eval_logger.info(
|
| 372 |
-
f"Successfully pushed sample results for task: {task_name} to the Hugging Face Hub. "
|
| 373 |
-
f"You can find them at: {repo_id}"
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
except Exception as e:
|
| 377 |
-
eval_logger.warning("Could not save sample results")
|
| 378 |
-
eval_logger.info(repr(e))
|
| 379 |
-
else:
|
| 380 |
-
eval_logger.info("Output path not provided, skipping saving sample results")
|
| 381 |
-
|
| 382 |
-
def recreate_metadata_card(self) -> None:
|
| 383 |
-
"""
|
| 384 |
-
Creates a metadata card for the evaluation results dataset and pushes it to the Hugging Face hub.
|
| 385 |
-
"""
|
| 386 |
-
|
| 387 |
-
eval_logger.info("Recreating metadata card")
|
| 388 |
-
repo_id = self.details_repo if self.public_repo else self.details_repo_private
|
| 389 |
-
|
| 390 |
-
files_in_repo = self.api.list_repo_files(repo_id=repo_id, repo_type="dataset")
|
| 391 |
-
results_files = get_results_filenames(files_in_repo)
|
| 392 |
-
sample_files = get_sample_results_filenames(files_in_repo)
|
| 393 |
-
|
| 394 |
-
# Build a dictionary to store the latest evaluation datetime for:
|
| 395 |
-
# - Each tested model and its aggregated results
|
| 396 |
-
# - Each task and sample results, if existing
|
| 397 |
-
# i.e. {
|
| 398 |
-
# "org__model_name__gsm8k": "2021-09-01T12:00:00",
|
| 399 |
-
# "org__model_name__ifeval": "2021-09-01T12:00:00",
|
| 400 |
-
# "org__model_name__results": "2021-09-01T12:00:00"
|
| 401 |
-
# }
|
| 402 |
-
latest_task_results_datetime = defaultdict(lambda: datetime.min.isoformat())
|
| 403 |
-
|
| 404 |
-
for file_path in sample_files:
|
| 405 |
-
file_path = Path(file_path)
|
| 406 |
-
filename = file_path.name
|
| 407 |
-
model_name = file_path.parent
|
| 408 |
-
task_name = get_file_task_name(filename)
|
| 409 |
-
results_datetime = get_file_datetime(filename)
|
| 410 |
-
task_name_sanitized = sanitize_task_name(task_name)
|
| 411 |
-
# Results and sample results for the same model and task will have the same datetime
|
| 412 |
-
samples_key = f"{model_name}__{task_name_sanitized}"
|
| 413 |
-
results_key = f"{model_name}__results"
|
| 414 |
-
latest_datetime = max(
|
| 415 |
-
latest_task_results_datetime[samples_key],
|
| 416 |
-
results_datetime,
|
| 417 |
-
)
|
| 418 |
-
latest_task_results_datetime[samples_key] = latest_datetime
|
| 419 |
-
latest_task_results_datetime[results_key] = max(
|
| 420 |
-
latest_task_results_datetime[results_key],
|
| 421 |
-
latest_datetime,
|
| 422 |
-
)
|
| 423 |
-
|
| 424 |
-
# Create metadata card
|
| 425 |
-
card_metadata = MetadataConfigs()
|
| 426 |
-
|
| 427 |
-
# Add the latest aggregated results to the metadata card for easy access
|
| 428 |
-
for file_path in results_files:
|
| 429 |
-
file_path = Path(file_path)
|
| 430 |
-
results_filename = file_path.name
|
| 431 |
-
model_name = file_path.parent
|
| 432 |
-
eval_date = get_file_datetime(results_filename)
|
| 433 |
-
eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
|
| 434 |
-
results_filename = Path("**") / Path(results_filename).name
|
| 435 |
-
config_name = f"{model_name}__results"
|
| 436 |
-
sanitized_last_eval_date_results = re.sub(
|
| 437 |
-
r"[^\w\.]", "_", latest_task_results_datetime[config_name]
|
| 438 |
-
)
|
| 439 |
-
|
| 440 |
-
if eval_date_sanitized == sanitized_last_eval_date_results:
|
| 441 |
-
# Ensure that all results files are listed in the metadata card
|
| 442 |
-
current_results = card_metadata.get(config_name, {"data_files": []})
|
| 443 |
-
current_results["data_files"].append(
|
| 444 |
-
{"split": eval_date_sanitized, "path": [str(results_filename)]}
|
| 445 |
-
)
|
| 446 |
-
card_metadata[config_name] = current_results
|
| 447 |
-
# If the results file is the newest, update the "latest" field in the metadata card
|
| 448 |
-
card_metadata[config_name]["data_files"].append(
|
| 449 |
-
{"split": "latest", "path": [str(results_filename)]}
|
| 450 |
-
)
|
| 451 |
-
|
| 452 |
-
# Add the tasks details configs
|
| 453 |
-
for file_path in sample_files:
|
| 454 |
-
file_path = Path(file_path)
|
| 455 |
-
filename = file_path.name
|
| 456 |
-
model_name = file_path.parent
|
| 457 |
-
task_name = get_file_task_name(filename)
|
| 458 |
-
eval_date = get_file_datetime(filename)
|
| 459 |
-
task_name_sanitized = sanitize_task_name(task_name)
|
| 460 |
-
eval_date_sanitized = re.sub(r"[^\w\.]", "_", eval_date)
|
| 461 |
-
results_filename = Path("**") / Path(filename).name
|
| 462 |
-
config_name = f"{model_name}__{task_name_sanitized}"
|
| 463 |
-
sanitized_last_eval_date_results = re.sub(
|
| 464 |
-
r"[^\w\.]", "_", latest_task_results_datetime[config_name]
|
| 465 |
-
)
|
| 466 |
-
if eval_date_sanitized == sanitized_last_eval_date_results:
|
| 467 |
-
# Ensure that all sample results files are listed in the metadata card
|
| 468 |
-
current_details_for_task = card_metadata.get(
|
| 469 |
-
config_name, {"data_files": []}
|
| 470 |
-
)
|
| 471 |
-
current_details_for_task["data_files"].append(
|
| 472 |
-
{"split": eval_date_sanitized, "path": [str(results_filename)]}
|
| 473 |
-
)
|
| 474 |
-
card_metadata[config_name] = current_details_for_task
|
| 475 |
-
# If the samples results file is the newest, update the "latest" field in the metadata card
|
| 476 |
-
card_metadata[config_name]["data_files"].append(
|
| 477 |
-
{"split": "latest", "path": [str(results_filename)]}
|
| 478 |
-
)
|
| 479 |
-
|
| 480 |
-
# Get latest results and extract info to update metadata card examples
|
| 481 |
-
latest_datetime = max(latest_task_results_datetime.values())
|
| 482 |
-
latest_model_name = max(
|
| 483 |
-
latest_task_results_datetime, key=lambda k: latest_task_results_datetime[k]
|
| 484 |
-
)
|
| 485 |
-
last_results_file = [
|
| 486 |
-
f for f in results_files if latest_datetime.replace(":", "-") in f
|
| 487 |
-
][0]
|
| 488 |
-
last_results_file_path = hf_hub_url(
|
| 489 |
-
repo_id=repo_id, filename=last_results_file, repo_type="dataset"
|
| 490 |
-
)
|
| 491 |
-
latest_results_file = load_dataset(
|
| 492 |
-
"json", data_files=last_results_file_path, split="train"
|
| 493 |
-
)
|
| 494 |
-
results_dict = latest_results_file["results"][0]
|
| 495 |
-
new_dictionary = {"all": results_dict}
|
| 496 |
-
new_dictionary.update(results_dict)
|
| 497 |
-
results_string = json.dumps(new_dictionary, indent=4)
|
| 498 |
-
|
| 499 |
-
dataset_summary = (
|
| 500 |
-
"Dataset automatically created during the evaluation run of model "
|
| 501 |
-
)
|
| 502 |
-
if self.general_config_tracker.model_source == "hf":
|
| 503 |
-
dataset_summary += f"[{self.general_config_tracker.model_name}](https://huggingface.co/{self.general_config_tracker.model_name})\n"
|
| 504 |
-
else:
|
| 505 |
-
dataset_summary += f"{self.general_config_tracker.model_name}\n"
|
| 506 |
-
dataset_summary += (
|
| 507 |
-
f"The dataset is composed of {len(card_metadata) - 1} configuration(s), each one corresponding to one of the evaluated task.\n\n"
|
| 508 |
-
f"The dataset has been created from {len(results_files)} run(s). Each run can be found as a specific split in each "
|
| 509 |
-
'configuration, the split being named using the timestamp of the run.The "train" split is always pointing to the latest results.\n\n'
|
| 510 |
-
'An additional configuration "results" store all the aggregated results of the run.\n\n'
|
| 511 |
-
"To load the details from a run, you can for instance do the following:\n"
|
| 512 |
-
)
|
| 513 |
-
if self.general_config_tracker.model_source == "hf":
|
| 514 |
-
dataset_summary += (
|
| 515 |
-
"```python\nfrom datasets import load_dataset\n"
|
| 516 |
-
f'data = load_dataset(\n\t"{repo_id}",\n\tname="{latest_model_name}",\n\tsplit="latest"\n)\n```\n\n'
|
| 517 |
-
)
|
| 518 |
-
dataset_summary += (
|
| 519 |
-
"## Latest results\n\n"
|
| 520 |
-
f"These are the [latest results from run {latest_datetime}]({last_results_file_path.replace('/resolve/', '/blob/')}) "
|
| 521 |
-
"(note that there might be results for other tasks in the repos if successive evals didn't cover the same tasks. "
|
| 522 |
-
'You find each in the results and the "latest" split for each eval):\n\n'
|
| 523 |
-
f"```python\n{results_string}\n```"
|
| 524 |
-
)
|
| 525 |
-
card_data = DatasetCardData(
|
| 526 |
-
dataset_summary=dataset_summary,
|
| 527 |
-
repo_url=f"https://huggingface.co/{self.general_config_tracker.model_name}",
|
| 528 |
-
pretty_name=f"Evaluation run of {self.general_config_tracker.model_name}",
|
| 529 |
-
leaderboard_url=self.leaderboard_url,
|
| 530 |
-
point_of_contact=self.point_of_contact,
|
| 531 |
-
)
|
| 532 |
-
card_metadata.to_dataset_card_data(card_data)
|
| 533 |
-
card = DatasetCard.from_template(
|
| 534 |
-
card_data,
|
| 535 |
-
pretty_name=card_data.pretty_name,
|
| 536 |
-
)
|
| 537 |
-
card.push_to_hub(repo_id, repo_type="dataset")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/loggers/utils.py
DELETED
|
@@ -1,149 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import os
|
| 3 |
-
import re
|
| 4 |
-
import subprocess
|
| 5 |
-
from importlib.metadata import version
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import Any, Dict, Optional, Tuple, Union
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
from torch.utils.collect_env import get_pretty_env_info
|
| 11 |
-
from transformers import __version__ as trans_version
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
|
| 18 |
-
"""Remove the ',none' substring from the input_string if it exists at the end.
|
| 19 |
-
|
| 20 |
-
Args:
|
| 21 |
-
input_string (str): The input string from which to remove the ',none' substring.
|
| 22 |
-
|
| 23 |
-
Returns:
|
| 24 |
-
Tuple[str, bool]: A tuple containing the modified input_string with the ',none' substring removed
|
| 25 |
-
and a boolean indicating whether the modification was made (True) or not (False).
|
| 26 |
-
"""
|
| 27 |
-
# Define the pattern to match ',none' at the end of the string
|
| 28 |
-
pattern = re.compile(r",none$")
|
| 29 |
-
|
| 30 |
-
# Use sub() to replace ',none' with an empty string
|
| 31 |
-
result = re.sub(pattern, "", input_string)
|
| 32 |
-
|
| 33 |
-
# check if the input_string changed
|
| 34 |
-
removed = result != input_string
|
| 35 |
-
|
| 36 |
-
return result, removed
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _handle_non_serializable(o: Any) -> Union[int, str, list]:
|
| 40 |
-
"""Handle non-serializable objects by converting them to serializable types.
|
| 41 |
-
|
| 42 |
-
Args:
|
| 43 |
-
o (Any): The object to be handled.
|
| 44 |
-
|
| 45 |
-
Returns:
|
| 46 |
-
Union[int, str, list]: The converted object. If the object is of type np.int64 or np.int32,
|
| 47 |
-
it will be converted to int. If the object is of type set, it will be converted
|
| 48 |
-
to a list. Otherwise, it will be converted to str.
|
| 49 |
-
"""
|
| 50 |
-
if isinstance(o, np.int64) or isinstance(o, np.int32):
|
| 51 |
-
return int(o)
|
| 52 |
-
elif isinstance(o, set):
|
| 53 |
-
return list(o)
|
| 54 |
-
else:
|
| 55 |
-
return str(o)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def get_commit_from_path(repo_path: Union[Path, str]) -> Optional[str]:
|
| 59 |
-
try:
|
| 60 |
-
git_folder = Path(repo_path, ".git")
|
| 61 |
-
if git_folder.is_file():
|
| 62 |
-
git_folder = Path(
|
| 63 |
-
git_folder.parent,
|
| 64 |
-
git_folder.read_text(encoding="utf-8").split("\n")[0].split(" ")[-1],
|
| 65 |
-
)
|
| 66 |
-
if Path(git_folder, "HEAD").exists():
|
| 67 |
-
head_name = (
|
| 68 |
-
Path(git_folder, "HEAD")
|
| 69 |
-
.read_text(encoding="utf-8")
|
| 70 |
-
.split("\n")[0]
|
| 71 |
-
.split(" ")[-1]
|
| 72 |
-
)
|
| 73 |
-
head_ref = Path(git_folder, head_name)
|
| 74 |
-
git_hash = head_ref.read_text(encoding="utf-8").replace("\n", "")
|
| 75 |
-
else:
|
| 76 |
-
git_hash = None
|
| 77 |
-
except Exception as err:
|
| 78 |
-
logger.debug(
|
| 79 |
-
f"Failed to retrieve a Git commit hash from path: {str(repo_path)}. Error: {err}"
|
| 80 |
-
)
|
| 81 |
-
return None
|
| 82 |
-
return git_hash
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def get_git_commit_hash():
|
| 86 |
-
"""
|
| 87 |
-
Gets the git commit hash of your current repo (if it exists).
|
| 88 |
-
Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
|
| 89 |
-
"""
|
| 90 |
-
try:
|
| 91 |
-
git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
|
| 92 |
-
git_hash = git_hash.decode()
|
| 93 |
-
except (subprocess.CalledProcessError, FileNotFoundError):
|
| 94 |
-
# FileNotFoundError occurs when git not installed on system
|
| 95 |
-
git_hash = get_commit_from_path(os.getcwd()) # git hash of repo if exists
|
| 96 |
-
return git_hash
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def add_env_info(storage: Dict[str, Any]):
|
| 100 |
-
try:
|
| 101 |
-
pretty_env_info = get_pretty_env_info()
|
| 102 |
-
except Exception as err:
|
| 103 |
-
pretty_env_info = str(err)
|
| 104 |
-
try:
|
| 105 |
-
lm_eval_version = version("lm_eval")
|
| 106 |
-
except Exception as err:
|
| 107 |
-
lm_eval_version = str(err)
|
| 108 |
-
transformers_version = trans_version
|
| 109 |
-
upper_dir_commit = get_commit_from_path(
|
| 110 |
-
Path(os.getcwd(), "..")
|
| 111 |
-
) # git hash of upper repo if exists
|
| 112 |
-
added_info = {
|
| 113 |
-
"pretty_env_info": pretty_env_info,
|
| 114 |
-
"transformers_version": transformers_version,
|
| 115 |
-
"lm_eval_version": lm_eval_version,
|
| 116 |
-
"upper_git_hash": upper_dir_commit, # in case this repo is submodule
|
| 117 |
-
}
|
| 118 |
-
storage.update(added_info)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def add_tokenizer_info(storage: Dict[str, Any], lm):
|
| 122 |
-
if getattr(lm, "tokenizer", False):
|
| 123 |
-
try:
|
| 124 |
-
tokenizer_info = {
|
| 125 |
-
"tokenizer_pad_token": [
|
| 126 |
-
lm.tokenizer.pad_token,
|
| 127 |
-
str(lm.tokenizer.pad_token_id),
|
| 128 |
-
],
|
| 129 |
-
"tokenizer_eos_token": [
|
| 130 |
-
lm.tokenizer.eos_token,
|
| 131 |
-
str(lm.tokenizer.eos_token_id),
|
| 132 |
-
],
|
| 133 |
-
"tokenizer_bos_token": [
|
| 134 |
-
lm.tokenizer.bos_token,
|
| 135 |
-
str(lm.tokenizer.bos_token_id),
|
| 136 |
-
],
|
| 137 |
-
"eot_token_id": getattr(lm, "eot_token_id", None),
|
| 138 |
-
"max_length": getattr(lm, "max_length", None),
|
| 139 |
-
}
|
| 140 |
-
storage.update(tokenizer_info)
|
| 141 |
-
except Exception as err:
|
| 142 |
-
logger.debug(
|
| 143 |
-
f"Logging detailed tokenizer info failed with {err}, skipping..."
|
| 144 |
-
)
|
| 145 |
-
# seems gguf and textsynth do not have tokenizer
|
| 146 |
-
else:
|
| 147 |
-
logger.debug(
|
| 148 |
-
"LM does not have a 'tokenizer' attribute, not logging tokenizer metadata to results."
|
| 149 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/loggers/wandb_logger.py
DELETED
|
@@ -1,358 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import json
|
| 3 |
-
import logging
|
| 4 |
-
from typing import Any, Dict, List, Literal, Tuple
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import pandas as pd
|
| 8 |
-
from packaging.version import Version
|
| 9 |
-
|
| 10 |
-
from lm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
logger = logging.getLogger(__name__)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def get_wandb_printer() -> Literal["Printer"]:
|
| 17 |
-
"""Returns a wandb printer instance for pretty stdout."""
|
| 18 |
-
from wandb.sdk.lib.printer import new_printer
|
| 19 |
-
|
| 20 |
-
printer = new_printer()
|
| 21 |
-
return printer
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class WandbLogger:
|
| 25 |
-
def __init__(self, init_args=None, config_args=None) -> None:
|
| 26 |
-
"""Attaches to wandb logger if already initialized. Otherwise, passes init_args to wandb.init() and config_args to wandb.config.update()
|
| 27 |
-
|
| 28 |
-
Args:
|
| 29 |
-
init_args Optional[Dict]: Arguments for init configuration.
|
| 30 |
-
config_args Optional[Dict]: Arguments for config
|
| 31 |
-
|
| 32 |
-
Parse and log the results returned from evaluator.simple_evaluate() with:
|
| 33 |
-
wandb_logger.post_init(results)
|
| 34 |
-
wandb_logger.log_eval_result()
|
| 35 |
-
wandb_logger.log_eval_samples(results["samples"])
|
| 36 |
-
"""
|
| 37 |
-
try:
|
| 38 |
-
import wandb
|
| 39 |
-
|
| 40 |
-
assert Version(wandb.__version__) >= Version("0.13.6")
|
| 41 |
-
if Version(wandb.__version__) < Version("0.13.6"):
|
| 42 |
-
wandb.require("report-editing:v0")
|
| 43 |
-
except Exception as e:
|
| 44 |
-
logger.warning(
|
| 45 |
-
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
|
| 46 |
-
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
|
| 47 |
-
f"{e}"
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
self.wandb_args: Dict[str, Any] = init_args or {}
|
| 51 |
-
self.wandb_config_args: Dict[str, Any] = config_args or {}
|
| 52 |
-
|
| 53 |
-
# pop the step key from the args to save for all logging calls
|
| 54 |
-
self.step = self.wandb_args.pop("step", None)
|
| 55 |
-
|
| 56 |
-
# initialize a W&B run
|
| 57 |
-
if wandb.run is None:
|
| 58 |
-
self.run = wandb.init(**self.wandb_args)
|
| 59 |
-
if self.wandb_config_args:
|
| 60 |
-
self.run.config.update(self.wandb_config_args)
|
| 61 |
-
else:
|
| 62 |
-
self.run = wandb.run
|
| 63 |
-
|
| 64 |
-
self.printer = get_wandb_printer()
|
| 65 |
-
|
| 66 |
-
def post_init(self, results: Dict[str, Any]) -> None:
|
| 67 |
-
self.results: Dict[str, Any] = copy.deepcopy(results)
|
| 68 |
-
self.task_names: List[str] = list(results.get("results", {}).keys())
|
| 69 |
-
self.group_names: List[str] = list(results.get("groups", {}).keys())
|
| 70 |
-
|
| 71 |
-
def _get_config(self) -> Dict[str, Any]:
|
| 72 |
-
"""Get configuration parameters."""
|
| 73 |
-
self.task_configs = self.results.get("configs", {})
|
| 74 |
-
cli_configs = self.results.get("config", {})
|
| 75 |
-
configs = {
|
| 76 |
-
"task_configs": self.task_configs,
|
| 77 |
-
"cli_configs": cli_configs,
|
| 78 |
-
}
|
| 79 |
-
|
| 80 |
-
return configs
|
| 81 |
-
|
| 82 |
-
def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
|
| 83 |
-
"""Sanitize the results dictionary."""
|
| 84 |
-
_results = copy.deepcopy(self.results.get("results", dict()))
|
| 85 |
-
|
| 86 |
-
# Remove None from the metric string name
|
| 87 |
-
tmp_results = copy.deepcopy(_results)
|
| 88 |
-
for task_name in self.task_names:
|
| 89 |
-
task_result = tmp_results.get(task_name, dict())
|
| 90 |
-
for metric_name, metric_value in task_result.items():
|
| 91 |
-
_metric_name, removed = remove_none_pattern(metric_name)
|
| 92 |
-
if removed:
|
| 93 |
-
_results[task_name][_metric_name] = metric_value
|
| 94 |
-
_results[task_name].pop(metric_name)
|
| 95 |
-
|
| 96 |
-
# remove string valued keys from the results dict
|
| 97 |
-
wandb_summary = {}
|
| 98 |
-
for task in self.task_names:
|
| 99 |
-
task_result = _results.get(task, dict())
|
| 100 |
-
for metric_name, metric_value in task_result.items():
|
| 101 |
-
if isinstance(metric_value, str):
|
| 102 |
-
wandb_summary[f"{task}/{metric_name}"] = metric_value
|
| 103 |
-
|
| 104 |
-
for summary_metric, summary_value in wandb_summary.items():
|
| 105 |
-
_task, _summary_metric = summary_metric.split("/")
|
| 106 |
-
_results[_task].pop(_summary_metric)
|
| 107 |
-
|
| 108 |
-
tmp_results = copy.deepcopy(_results)
|
| 109 |
-
for task_name, task_results in tmp_results.items():
|
| 110 |
-
for metric_name, metric_value in task_results.items():
|
| 111 |
-
_results[f"{task_name}/{metric_name}"] = metric_value
|
| 112 |
-
_results[task_name].pop(metric_name)
|
| 113 |
-
for task in self.task_names:
|
| 114 |
-
_results.pop(task)
|
| 115 |
-
|
| 116 |
-
return wandb_summary, _results
|
| 117 |
-
|
| 118 |
-
def _log_results_as_table(self) -> None:
|
| 119 |
-
"""Generate and log evaluation results as a table to W&B."""
|
| 120 |
-
columns = [
|
| 121 |
-
"Version",
|
| 122 |
-
"Filter",
|
| 123 |
-
"num_fewshot",
|
| 124 |
-
"Metric",
|
| 125 |
-
"Value",
|
| 126 |
-
"Stderr",
|
| 127 |
-
]
|
| 128 |
-
|
| 129 |
-
def make_table(columns: List[str], key: str = "results"):
|
| 130 |
-
import wandb
|
| 131 |
-
|
| 132 |
-
table = wandb.Table(columns=columns)
|
| 133 |
-
results = copy.deepcopy(self.results)
|
| 134 |
-
|
| 135 |
-
for k, dic in results.get(key).items():
|
| 136 |
-
if k in self.group_names and not key == "groups":
|
| 137 |
-
continue
|
| 138 |
-
version = results.get("versions").get(k)
|
| 139 |
-
if version == "N/A":
|
| 140 |
-
version = None
|
| 141 |
-
n = results.get("n-shot").get(k)
|
| 142 |
-
|
| 143 |
-
for (mf), v in dic.items():
|
| 144 |
-
m, _, f = mf.partition(",")
|
| 145 |
-
if m.endswith("_stderr"):
|
| 146 |
-
continue
|
| 147 |
-
if m == "alias":
|
| 148 |
-
continue
|
| 149 |
-
|
| 150 |
-
if m + "_stderr" + "," + f in dic:
|
| 151 |
-
se = dic[m + "_stderr" + "," + f]
|
| 152 |
-
if se != "N/A":
|
| 153 |
-
se = "%.4f" % se
|
| 154 |
-
table.add_data(*[k, version, f, n, m, str(v), str(se)])
|
| 155 |
-
else:
|
| 156 |
-
table.add_data(*[k, version, f, n, m, str(v), ""])
|
| 157 |
-
|
| 158 |
-
return table
|
| 159 |
-
|
| 160 |
-
# log the complete eval result to W&B Table
|
| 161 |
-
table = make_table(["Tasks"] + columns, "results")
|
| 162 |
-
self.run.log({"evaluation/eval_results": table}, step=self.step)
|
| 163 |
-
|
| 164 |
-
if "groups" in self.results.keys():
|
| 165 |
-
table = make_table(["Groups"] + columns, "groups")
|
| 166 |
-
self.run.log({"evaluation/group_eval_results": table}, step=self.step)
|
| 167 |
-
|
| 168 |
-
def _log_results_as_artifact(self) -> None:
|
| 169 |
-
"""Log results as JSON artifact to W&B."""
|
| 170 |
-
import wandb
|
| 171 |
-
|
| 172 |
-
dumped = json.dumps(
|
| 173 |
-
self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
|
| 174 |
-
)
|
| 175 |
-
artifact = wandb.Artifact("results", type="eval_results")
|
| 176 |
-
with artifact.new_file("results.json", mode="w", encoding="utf-8") as f:
|
| 177 |
-
f.write(dumped)
|
| 178 |
-
self.run.log_artifact(artifact)
|
| 179 |
-
|
| 180 |
-
def log_eval_result(self) -> None:
|
| 181 |
-
"""Log evaluation results to W&B."""
|
| 182 |
-
# Log configs to wandb
|
| 183 |
-
configs = self._get_config()
|
| 184 |
-
self.run.config.update(configs, allow_val_change=self.step is not None)
|
| 185 |
-
|
| 186 |
-
wandb_summary, self.wandb_results = self._sanitize_results_dict()
|
| 187 |
-
# update wandb.run.summary with items that were removed
|
| 188 |
-
self.run.summary.update(wandb_summary)
|
| 189 |
-
# Log the evaluation metrics to wandb
|
| 190 |
-
self.run.log(self.wandb_results, step=self.step)
|
| 191 |
-
# Log the evaluation metrics as W&B Table
|
| 192 |
-
self._log_results_as_table()
|
| 193 |
-
# Log the results dict as json to W&B Artifacts
|
| 194 |
-
self._log_results_as_artifact()
|
| 195 |
-
|
| 196 |
-
def _generate_dataset(
|
| 197 |
-
self, data: List[Dict[str, Any]], config: Dict[str, Any]
|
| 198 |
-
) -> pd.DataFrame:
|
| 199 |
-
"""Generate a dataset from evaluation data.
|
| 200 |
-
|
| 201 |
-
Args:
|
| 202 |
-
data (List[Dict[str, Any]]): The data to generate a dataset for.
|
| 203 |
-
config (Dict[str, Any]): The configuration of the task.
|
| 204 |
-
|
| 205 |
-
Returns:
|
| 206 |
-
pd.DataFrame: A dataframe that is ready to be uploaded to W&B.
|
| 207 |
-
"""
|
| 208 |
-
ids = [x["doc_id"] for x in data]
|
| 209 |
-
labels = [x["target"] for x in data]
|
| 210 |
-
instance = [""] * len(ids)
|
| 211 |
-
resps = [""] * len(ids)
|
| 212 |
-
filtered_resps = [""] * len(ids)
|
| 213 |
-
model_outputs = {}
|
| 214 |
-
|
| 215 |
-
metrics_list = config["metric_list"]
|
| 216 |
-
metrics = {}
|
| 217 |
-
for metric in metrics_list:
|
| 218 |
-
metric = metric.get("metric")
|
| 219 |
-
if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]:
|
| 220 |
-
metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data]
|
| 221 |
-
if metric in ["byte_perplexity", "bits_per_byte"]:
|
| 222 |
-
metrics[f"{metric}_bytes"] = [x[metric][1] for x in data]
|
| 223 |
-
else:
|
| 224 |
-
metrics[f"{metric}_words"] = [x[metric][1] for x in data]
|
| 225 |
-
else:
|
| 226 |
-
metrics[metric] = [x[metric] for x in data]
|
| 227 |
-
|
| 228 |
-
if config["output_type"] == "loglikelihood":
|
| 229 |
-
instance = [x["arguments"][0][0] for x in data]
|
| 230 |
-
labels = [x["arguments"][0][1] for x in data]
|
| 231 |
-
resps = [
|
| 232 |
-
f"log probability of continuation is {x['resps'][0][0][0]} "
|
| 233 |
-
+ "\n\n"
|
| 234 |
-
+ "continuation will {} generated with greedy sampling".format(
|
| 235 |
-
"not be" if not x["resps"][0][0][1] else "be"
|
| 236 |
-
)
|
| 237 |
-
for x in data
|
| 238 |
-
]
|
| 239 |
-
filtered_resps = [
|
| 240 |
-
f"log probability of continuation is {x['filtered_resps'][0][0]} "
|
| 241 |
-
+ "\n\n"
|
| 242 |
-
+ "continuation will {} generated with greedy sampling".format(
|
| 243 |
-
"not be" if not x["filtered_resps"][0][1] else "be"
|
| 244 |
-
)
|
| 245 |
-
for x in data
|
| 246 |
-
]
|
| 247 |
-
elif config["output_type"] == "multiple_choice":
|
| 248 |
-
instance = [x["arguments"][0][0] for x in data]
|
| 249 |
-
choices = [
|
| 250 |
-
"\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])])
|
| 251 |
-
for x in data
|
| 252 |
-
]
|
| 253 |
-
resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
|
| 254 |
-
filtered_resps = [
|
| 255 |
-
np.argmax([n[0] for n in x["filtered_resps"]]) for x in data
|
| 256 |
-
]
|
| 257 |
-
elif config["output_type"] == "loglikelihood_rolling":
|
| 258 |
-
instance = [x["arguments"][0][0] for x in data]
|
| 259 |
-
resps = [x["resps"][0][0] for x in data]
|
| 260 |
-
filtered_resps = [x["filtered_resps"][0] for x in data]
|
| 261 |
-
elif config["output_type"] == "generate_until":
|
| 262 |
-
instance = [x["arguments"][0][0] for x in data]
|
| 263 |
-
resps = [x["resps"][0][0] for x in data]
|
| 264 |
-
filtered_resps = [x["filtered_resps"][0] for x in data]
|
| 265 |
-
|
| 266 |
-
model_outputs["raw_predictions"] = resps
|
| 267 |
-
model_outputs["filtered_predictions"] = filtered_resps
|
| 268 |
-
|
| 269 |
-
df_data = {
|
| 270 |
-
"id": ids,
|
| 271 |
-
"data": instance,
|
| 272 |
-
}
|
| 273 |
-
if config["output_type"] == "multiple_choice":
|
| 274 |
-
df_data["choices"] = choices
|
| 275 |
-
|
| 276 |
-
tmp_data = {
|
| 277 |
-
"input_len": [len(x) for x in instance],
|
| 278 |
-
"labels": labels,
|
| 279 |
-
"output_type": config["output_type"],
|
| 280 |
-
}
|
| 281 |
-
df_data.update(tmp_data)
|
| 282 |
-
df_data.update(model_outputs)
|
| 283 |
-
df_data.update(metrics)
|
| 284 |
-
|
| 285 |
-
return pd.DataFrame(df_data)
|
| 286 |
-
|
| 287 |
-
def _log_samples_as_artifact(
|
| 288 |
-
self, data: List[Dict[str, Any]], task_name: str
|
| 289 |
-
) -> None:
|
| 290 |
-
import wandb
|
| 291 |
-
|
| 292 |
-
# log the samples as an artifact
|
| 293 |
-
dumped = json.dumps(
|
| 294 |
-
data,
|
| 295 |
-
indent=2,
|
| 296 |
-
default=_handle_non_serializable,
|
| 297 |
-
ensure_ascii=False,
|
| 298 |
-
)
|
| 299 |
-
artifact = wandb.Artifact(f"{task_name}", type="samples_by_task")
|
| 300 |
-
with artifact.new_file(
|
| 301 |
-
f"{task_name}_eval_samples.json", mode="w", encoding="utf-8"
|
| 302 |
-
) as f:
|
| 303 |
-
f.write(dumped)
|
| 304 |
-
self.run.log_artifact(artifact)
|
| 305 |
-
# artifact.wait()
|
| 306 |
-
|
| 307 |
-
def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
|
| 308 |
-
"""Log evaluation samples to W&B.
|
| 309 |
-
|
| 310 |
-
Args:
|
| 311 |
-
samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task.
|
| 312 |
-
"""
|
| 313 |
-
task_names: List[str] = [
|
| 314 |
-
x for x in self.task_names if x not in self.group_names
|
| 315 |
-
]
|
| 316 |
-
|
| 317 |
-
ungrouped_tasks = []
|
| 318 |
-
tasks_by_groups = {}
|
| 319 |
-
|
| 320 |
-
for task_name in task_names:
|
| 321 |
-
group_names = self.task_configs[task_name].get("group", None)
|
| 322 |
-
if group_names:
|
| 323 |
-
if isinstance(group_names, str):
|
| 324 |
-
group_names = [group_names]
|
| 325 |
-
|
| 326 |
-
for group_name in group_names:
|
| 327 |
-
if not tasks_by_groups.get(group_name):
|
| 328 |
-
tasks_by_groups[group_name] = [task_name]
|
| 329 |
-
else:
|
| 330 |
-
tasks_by_groups[group_name].append(task_name)
|
| 331 |
-
else:
|
| 332 |
-
ungrouped_tasks.append(task_name)
|
| 333 |
-
|
| 334 |
-
for task_name in ungrouped_tasks:
|
| 335 |
-
eval_preds = samples[task_name]
|
| 336 |
-
|
| 337 |
-
# log the samples as a W&B Table
|
| 338 |
-
df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
|
| 339 |
-
self.run.log({f"{task_name}_eval_results": df}, step=self.step)
|
| 340 |
-
|
| 341 |
-
# log the samples as a json file as W&B Artifact
|
| 342 |
-
self._log_samples_as_artifact(eval_preds, task_name)
|
| 343 |
-
|
| 344 |
-
for group, grouped_tasks in tasks_by_groups.items():
|
| 345 |
-
grouped_df = pd.DataFrame()
|
| 346 |
-
for task_name in grouped_tasks:
|
| 347 |
-
eval_preds = samples[task_name]
|
| 348 |
-
df = self._generate_dataset(
|
| 349 |
-
eval_preds, self.task_configs.get(task_name)
|
| 350 |
-
)
|
| 351 |
-
df["group"] = group
|
| 352 |
-
df["task"] = task_name
|
| 353 |
-
grouped_df = pd.concat([grouped_df, df], ignore_index=True)
|
| 354 |
-
|
| 355 |
-
# log the samples as a json file as W&B Artifact
|
| 356 |
-
self._log_samples_as_artifact(eval_preds, task_name)
|
| 357 |
-
|
| 358 |
-
self.run.log({f"{group}_eval_results": grouped_df}, step=self.step)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/__init__.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
from . import (
|
| 2 |
-
anthropic_llms,
|
| 3 |
-
api_models,
|
| 4 |
-
dummy,
|
| 5 |
-
gguf,
|
| 6 |
-
hf_audiolm,
|
| 7 |
-
hf_steered,
|
| 8 |
-
hf_vlms,
|
| 9 |
-
huggingface,
|
| 10 |
-
ibm_watsonx_ai,
|
| 11 |
-
mamba_lm,
|
| 12 |
-
nemo_lm,
|
| 13 |
-
neuralmagic,
|
| 14 |
-
neuron_optimum,
|
| 15 |
-
openai_completions,
|
| 16 |
-
optimum_ipex,
|
| 17 |
-
optimum_lm,
|
| 18 |
-
sglang_causallms,
|
| 19 |
-
sglang_generate_API,
|
| 20 |
-
textsynth,
|
| 21 |
-
vllm_causallms,
|
| 22 |
-
vllm_vlms,
|
| 23 |
-
)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# TODO: implement __all__
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
try:
|
| 30 |
-
# enable hf hub transfer if available
|
| 31 |
-
import hf_transfer # type: ignore # noqa
|
| 32 |
-
import huggingface_hub.constants # type: ignore
|
| 33 |
-
|
| 34 |
-
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
| 35 |
-
except ImportError:
|
| 36 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/anthropic_llms.py
DELETED
|
@@ -1,367 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import os
|
| 3 |
-
from functools import cached_property
|
| 4 |
-
from typing import Any, Dict, List, Tuple, Union
|
| 5 |
-
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
|
| 8 |
-
from lm_eval.api.model import LM
|
| 9 |
-
from lm_eval.api.registry import register_model
|
| 10 |
-
from lm_eval.models.openai_completions import LocalCompletionsAPI
|
| 11 |
-
from lm_eval.models.utils import handle_stop_sequences, retry_on_specific_exceptions
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
eval_logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def anthropic_completion(
|
| 18 |
-
client, #: anthropic.Anthropic,
|
| 19 |
-
model: str,
|
| 20 |
-
prompt: str,
|
| 21 |
-
max_tokens_to_sample: int,
|
| 22 |
-
temperature: float,
|
| 23 |
-
stop: List[str],
|
| 24 |
-
**kwargs: Any,
|
| 25 |
-
) -> str:
|
| 26 |
-
"""Wrapper function around the Anthropic completion API client with exponential back-off
|
| 27 |
-
in case of RateLimitError.
|
| 28 |
-
|
| 29 |
-
params:
|
| 30 |
-
client: anthropic.Anthropic
|
| 31 |
-
Anthropic API client
|
| 32 |
-
model: str
|
| 33 |
-
Anthropic model e.g. 'claude-instant-v1', 'claude-2'
|
| 34 |
-
prompt: str
|
| 35 |
-
Prompt to feed to the model
|
| 36 |
-
max_tokens_to_sample: int
|
| 37 |
-
Maximum number of tokens to sample from the model
|
| 38 |
-
temperature: float
|
| 39 |
-
Sampling temperature
|
| 40 |
-
stop: List[str]
|
| 41 |
-
List of stop sequences
|
| 42 |
-
kwargs: Any
|
| 43 |
-
Additional model_args to pass to the API client
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
try:
|
| 47 |
-
import anthropic
|
| 48 |
-
except ModuleNotFoundError as exception:
|
| 49 |
-
raise type(exception)(
|
| 50 |
-
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
|
| 51 |
-
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
|
| 52 |
-
)
|
| 53 |
-
|
| 54 |
-
def _exception_callback(e: Exception, sleep_time: float) -> None:
|
| 55 |
-
eval_logger.warning(
|
| 56 |
-
f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
|
| 57 |
-
)
|
| 58 |
-
|
| 59 |
-
@retry_on_specific_exceptions(
|
| 60 |
-
on_exceptions=[anthropic.RateLimitError],
|
| 61 |
-
max_retries=None, # retry forever, consider changing
|
| 62 |
-
on_exception_callback=_exception_callback,
|
| 63 |
-
)
|
| 64 |
-
def completion():
|
| 65 |
-
response = client.completions.create(
|
| 66 |
-
prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
|
| 67 |
-
model=model,
|
| 68 |
-
# NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
|
| 69 |
-
# (e.g. gsm8k's ":") may truncate a lot of the input.
|
| 70 |
-
stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
|
| 71 |
-
max_tokens_to_sample=max_tokens_to_sample,
|
| 72 |
-
temperature=temperature,
|
| 73 |
-
**kwargs,
|
| 74 |
-
)
|
| 75 |
-
return response.completion
|
| 76 |
-
|
| 77 |
-
return completion()
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def anthropic_chat(
|
| 81 |
-
client, #: anthropic.Anthropic,
|
| 82 |
-
model: str,
|
| 83 |
-
prompt: str,
|
| 84 |
-
max_tokens: int,
|
| 85 |
-
temperature: float,
|
| 86 |
-
stop: List[str],
|
| 87 |
-
**kwargs: Any,
|
| 88 |
-
) -> str:
|
| 89 |
-
"""Wrapper function around the Anthropic completion API client with exponential back-off
|
| 90 |
-
in case of RateLimitError.
|
| 91 |
-
|
| 92 |
-
params:
|
| 93 |
-
client: anthropic.Anthropic
|
| 94 |
-
Anthropic API client
|
| 95 |
-
model: str
|
| 96 |
-
Anthropic model e.g. 'claude-3-opus-20240229', 'claude-3-sonnet-20240229'
|
| 97 |
-
prompt: str
|
| 98 |
-
Prompt to feed to the model
|
| 99 |
-
max_tokens: int
|
| 100 |
-
Maximum number of tokens to sample from the model
|
| 101 |
-
temperature: float
|
| 102 |
-
Sampling temperature
|
| 103 |
-
stop: List[str]
|
| 104 |
-
List of stop sequences
|
| 105 |
-
kwargs: Any
|
| 106 |
-
Additional model_args to pass to the API client
|
| 107 |
-
"""
|
| 108 |
-
|
| 109 |
-
try:
|
| 110 |
-
import anthropic
|
| 111 |
-
except ModuleNotFoundError as exception:
|
| 112 |
-
raise type(exception)(
|
| 113 |
-
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
|
| 114 |
-
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
-
def _exception_callback(e: Exception, sleep_time: float) -> None:
|
| 118 |
-
eval_logger.warning(
|
| 119 |
-
f"RateLimitError occurred: {e.__cause__}\n Retrying in {sleep_time} seconds"
|
| 120 |
-
)
|
| 121 |
-
|
| 122 |
-
@retry_on_specific_exceptions(
|
| 123 |
-
on_exceptions=[
|
| 124 |
-
anthropic.RateLimitError,
|
| 125 |
-
anthropic.APIConnectionError,
|
| 126 |
-
anthropic.APIStatusError,
|
| 127 |
-
],
|
| 128 |
-
max_retries=None, # retry forever, consider changing
|
| 129 |
-
on_exception_callback=_exception_callback,
|
| 130 |
-
)
|
| 131 |
-
def messages():
|
| 132 |
-
response = client.messages.create(
|
| 133 |
-
model=model,
|
| 134 |
-
max_tokens=max_tokens,
|
| 135 |
-
temperature=temperature,
|
| 136 |
-
messages=[{"role": "user", "content": f"{prompt}"}],
|
| 137 |
-
**kwargs,
|
| 138 |
-
)
|
| 139 |
-
return response.content[0].text
|
| 140 |
-
|
| 141 |
-
return messages()
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
@register_model("anthropic-completions")
|
| 145 |
-
class AnthropicLM(LM):
|
| 146 |
-
REQ_CHUNK_SIZE = 20 # TODO: not used
|
| 147 |
-
|
| 148 |
-
def __init__(
|
| 149 |
-
self,
|
| 150 |
-
batch_size: int = 1,
|
| 151 |
-
model: str = "claude-2.0",
|
| 152 |
-
max_tokens_to_sample: int = 256,
|
| 153 |
-
temperature: float = 0, # defaults to 1
|
| 154 |
-
**kwargs, # top_p, top_k, etc.
|
| 155 |
-
) -> None:
|
| 156 |
-
"""Anthropic API wrapper.
|
| 157 |
-
|
| 158 |
-
:param model: str
|
| 159 |
-
Anthropic model e.g. 'claude-instant-v1', 'claude-2'
|
| 160 |
-
:param max_tokens_to_sample: int
|
| 161 |
-
Maximum number of tokens to sample from the model
|
| 162 |
-
:param temperature: float
|
| 163 |
-
Sampling temperature
|
| 164 |
-
:param kwargs: Any
|
| 165 |
-
Additional model_args to pass to the API client
|
| 166 |
-
"""
|
| 167 |
-
super().__init__()
|
| 168 |
-
|
| 169 |
-
try:
|
| 170 |
-
import anthropic
|
| 171 |
-
except ModuleNotFoundError as exception:
|
| 172 |
-
raise type(exception)(
|
| 173 |
-
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
|
| 174 |
-
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
self.model = model
|
| 178 |
-
# defaults to os.environ.get("ANTHROPIC_API_KEY")
|
| 179 |
-
self.client = anthropic.Anthropic()
|
| 180 |
-
self.temperature = temperature
|
| 181 |
-
self.max_tokens_to_sample = max_tokens_to_sample
|
| 182 |
-
self.tokenizer = self.client.get_tokenizer()
|
| 183 |
-
self.kwargs = kwargs
|
| 184 |
-
|
| 185 |
-
@property
|
| 186 |
-
def eot_token_id(self):
|
| 187 |
-
# Not sure but anthropic.HUMAN_PROMPT ?
|
| 188 |
-
raise NotImplementedError("No idea about anthropic tokenization.")
|
| 189 |
-
|
| 190 |
-
@property
|
| 191 |
-
def max_length(self) -> int:
|
| 192 |
-
return 2048
|
| 193 |
-
|
| 194 |
-
@property
|
| 195 |
-
def max_gen_toks(self) -> int:
|
| 196 |
-
return self.max_tokens_to_sample
|
| 197 |
-
|
| 198 |
-
@property
|
| 199 |
-
def batch_size(self):
|
| 200 |
-
# Isn't used because we override _loglikelihood_tokens
|
| 201 |
-
raise NotImplementedError("No support for logits.")
|
| 202 |
-
|
| 203 |
-
@property
|
| 204 |
-
def device(self):
|
| 205 |
-
# Isn't used because we override _loglikelihood_tokens
|
| 206 |
-
raise NotImplementedError("No support for logits.")
|
| 207 |
-
|
| 208 |
-
def tok_encode(self, string: str) -> List[int]:
|
| 209 |
-
return self.tokenizer.encode(string).ids
|
| 210 |
-
|
| 211 |
-
def tok_decode(self, tokens: List[int]) -> str:
|
| 212 |
-
return self.tokenizer.decode(tokens)
|
| 213 |
-
|
| 214 |
-
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
|
| 215 |
-
raise NotImplementedError("No support for logits.")
|
| 216 |
-
|
| 217 |
-
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
|
| 218 |
-
try:
|
| 219 |
-
import anthropic
|
| 220 |
-
except ModuleNotFoundError as exception:
|
| 221 |
-
raise type(exception)(
|
| 222 |
-
"attempted to use 'anthropic' LM type, but package `anthropic` is not installed. \
|
| 223 |
-
please install anthropic via `pip install 'lm-eval[anthropic]'` or `pip install -e '.[anthropic]'`",
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
if not requests:
|
| 227 |
-
return []
|
| 228 |
-
|
| 229 |
-
_requests: List[Tuple[str, dict]] = [req.args for req in requests]
|
| 230 |
-
|
| 231 |
-
res = []
|
| 232 |
-
for request in tqdm(_requests, disable=disable_tqdm):
|
| 233 |
-
try:
|
| 234 |
-
inp = request[0]
|
| 235 |
-
request_args = request[1]
|
| 236 |
-
# generation_kwargs
|
| 237 |
-
until = request_args.get("until")
|
| 238 |
-
max_gen_toks = request_args.get("max_gen_toks", self.max_length)
|
| 239 |
-
temperature = request_args.get("temperature", self.temperature)
|
| 240 |
-
response = anthropic_completion(
|
| 241 |
-
client=self.client,
|
| 242 |
-
model=self.model,
|
| 243 |
-
prompt=inp,
|
| 244 |
-
max_tokens_to_sample=max_gen_toks,
|
| 245 |
-
temperature=temperature, # TODO: implement non-greedy sampling for Anthropic
|
| 246 |
-
stop=until, # type: ignore
|
| 247 |
-
**self.kwargs,
|
| 248 |
-
)
|
| 249 |
-
res.append(response)
|
| 250 |
-
|
| 251 |
-
self.cache_hook.add_partial("generate_until", request, response)
|
| 252 |
-
except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
|
| 253 |
-
eval_logger.critical(f"Server unreachable: {e.__cause__}")
|
| 254 |
-
break
|
| 255 |
-
except anthropic.APIStatusError as e: # type: ignore # noqa: F821
|
| 256 |
-
eval_logger.critical(f"API error {e.status_code}: {e.message}")
|
| 257 |
-
break
|
| 258 |
-
|
| 259 |
-
return res
|
| 260 |
-
|
| 261 |
-
def _model_call(self, inps):
|
| 262 |
-
# Isn't used because we override _loglikelihood_tokens
|
| 263 |
-
raise NotImplementedError()
|
| 264 |
-
|
| 265 |
-
def _model_generate(self, context, max_length, eos_token_id):
|
| 266 |
-
# Isn't used because we override generate_until
|
| 267 |
-
raise NotImplementedError()
|
| 268 |
-
|
| 269 |
-
def loglikelihood(self, requests, disable_tqdm: bool = False):
|
| 270 |
-
raise NotImplementedError("No support for logits.")
|
| 271 |
-
|
| 272 |
-
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
|
| 273 |
-
raise NotImplementedError("No support for logits.")
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
@register_model("anthropic-chat", "anthropic-chat-completions")
|
| 277 |
-
class AnthropicChat(LocalCompletionsAPI):
|
| 278 |
-
def __init__(
|
| 279 |
-
self,
|
| 280 |
-
base_url="https://api.anthropic.com/v1/messages",
|
| 281 |
-
tokenizer_backend=None,
|
| 282 |
-
**kwargs,
|
| 283 |
-
):
|
| 284 |
-
super().__init__(
|
| 285 |
-
base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
|
| 286 |
-
)
|
| 287 |
-
eval_logger.warning(
|
| 288 |
-
"Chat completions does not support batching. Defaulting to batch size 1."
|
| 289 |
-
)
|
| 290 |
-
self._batch_size = 1
|
| 291 |
-
self.anthropic_version = "2023-06-01"
|
| 292 |
-
eval_logger.warning(
|
| 293 |
-
f"Using Anthropic Version: {self.anthropic_version}. Confirm the current version here: https://docs.anthropic.com/en/api/versioning"
|
| 294 |
-
)
|
| 295 |
-
|
| 296 |
-
@cached_property
|
| 297 |
-
def api_key(self):
|
| 298 |
-
"""Override this property to return the API key for the API request."""
|
| 299 |
-
key = os.environ.get("ANTHROPIC_API_KEY", None)
|
| 300 |
-
if key is None:
|
| 301 |
-
raise ValueError(
|
| 302 |
-
"API key not found. Please set the ANTHROPIC_API_KEY environment variable."
|
| 303 |
-
)
|
| 304 |
-
return key
|
| 305 |
-
|
| 306 |
-
@cached_property
|
| 307 |
-
def header(self):
|
| 308 |
-
return {
|
| 309 |
-
"x-api-key": f"{self.api_key}",
|
| 310 |
-
"anthropic-version": self.anthropic_version,
|
| 311 |
-
}
|
| 312 |
-
|
| 313 |
-
def _create_payload(
|
| 314 |
-
self,
|
| 315 |
-
messages: List[Dict],
|
| 316 |
-
generate=True,
|
| 317 |
-
gen_kwargs: dict = None,
|
| 318 |
-
eos="\n\nHuman:",
|
| 319 |
-
**kwargs,
|
| 320 |
-
) -> dict:
|
| 321 |
-
system = (
|
| 322 |
-
messages[0].get("content") if messages[0].get("role") == "system" else None
|
| 323 |
-
)
|
| 324 |
-
if system:
|
| 325 |
-
messages = messages[1:]
|
| 326 |
-
gen_kwargs.pop("do_sample", False)
|
| 327 |
-
max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
|
| 328 |
-
temperature = gen_kwargs.pop("temperature", 0)
|
| 329 |
-
stop = handle_stop_sequences(gen_kwargs.pop("until", ["\n\nHuman:"]), eos=eos)
|
| 330 |
-
if not isinstance(stop, list):
|
| 331 |
-
stop = [stop]
|
| 332 |
-
out = {
|
| 333 |
-
"messages": messages,
|
| 334 |
-
"model": self.model,
|
| 335 |
-
"max_tokens": max_tokens,
|
| 336 |
-
"temperature": temperature,
|
| 337 |
-
"stop_sequences": stop,
|
| 338 |
-
**gen_kwargs,
|
| 339 |
-
}
|
| 340 |
-
if system:
|
| 341 |
-
out["system"] = system
|
| 342 |
-
return out
|
| 343 |
-
|
| 344 |
-
def parse_generations(
|
| 345 |
-
self, outputs: Union[Dict, List[Dict]], **kwargs
|
| 346 |
-
) -> List[str]:
|
| 347 |
-
res = []
|
| 348 |
-
if not isinstance(outputs, list):
|
| 349 |
-
outputs = [outputs]
|
| 350 |
-
for out in outputs:
|
| 351 |
-
for choices in out["content"]:
|
| 352 |
-
res.append(choices["text"])
|
| 353 |
-
return res
|
| 354 |
-
|
| 355 |
-
def tok_encode(
|
| 356 |
-
self,
|
| 357 |
-
string: str,
|
| 358 |
-
left_truncate_len=None,
|
| 359 |
-
add_special_tokens=None,
|
| 360 |
-
**kwargs,
|
| 361 |
-
) -> List[str]:
|
| 362 |
-
return [string]
|
| 363 |
-
|
| 364 |
-
def loglikelihood(self, requests, **kwargs):
|
| 365 |
-
raise NotImplementedError(
|
| 366 |
-
"Anthropic Chat Completions API does not support the return of loglikelihood"
|
| 367 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/api_models.py
DELETED
|
@@ -1,799 +0,0 @@
|
|
| 1 |
-
import abc
|
| 2 |
-
import asyncio
|
| 3 |
-
import copy
|
| 4 |
-
import itertools
|
| 5 |
-
import json
|
| 6 |
-
import logging
|
| 7 |
-
from functools import cached_property
|
| 8 |
-
from typing import (
|
| 9 |
-
TYPE_CHECKING,
|
| 10 |
-
Any,
|
| 11 |
-
Awaitable,
|
| 12 |
-
Callable,
|
| 13 |
-
Dict,
|
| 14 |
-
Iterable,
|
| 15 |
-
List,
|
| 16 |
-
Literal,
|
| 17 |
-
NamedTuple,
|
| 18 |
-
Optional,
|
| 19 |
-
Tuple,
|
| 20 |
-
Union,
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
try:
|
| 25 |
-
import requests
|
| 26 |
-
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
| 27 |
-
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential
|
| 28 |
-
from tqdm import tqdm
|
| 29 |
-
from tqdm.asyncio import tqdm_asyncio
|
| 30 |
-
except ModuleNotFoundError:
|
| 31 |
-
pass
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
import base64
|
| 35 |
-
from importlib.util import find_spec
|
| 36 |
-
from io import BytesIO
|
| 37 |
-
|
| 38 |
-
from lm_eval import utils
|
| 39 |
-
from lm_eval.api.instance import Instance
|
| 40 |
-
from lm_eval.api.model import TemplateLM
|
| 41 |
-
from lm_eval.models.utils import Collator, chunks, configure_pad_token
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
if TYPE_CHECKING:
|
| 45 |
-
from PIL import Image
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
eval_logger = logging.getLogger(__name__)
|
| 49 |
-
|
| 50 |
-
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
# utility class to keep track of json encoded chats
|
| 54 |
-
class JsonChatStr(NamedTuple):
|
| 55 |
-
prompt: str
|
| 56 |
-
|
| 57 |
-
def encode(self, encoding):
|
| 58 |
-
return self.prompt.encode(encoding)
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def create_image_prompt(
|
| 62 |
-
imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
|
| 63 |
-
) -> dict:
|
| 64 |
-
"""
|
| 65 |
-
|
| 66 |
-
Parameters
|
| 67 |
-
----------
|
| 68 |
-
img : list[PIL.Image.Image]
|
| 69 |
-
The list of images to encode to base64
|
| 70 |
-
chat : dict
|
| 71 |
-
fmt : str, optional
|
| 72 |
-
Any format Pillow understands (e.g. "PNG", "JPEG").
|
| 73 |
-
Defaults to "PNG".
|
| 74 |
-
|
| 75 |
-
Returns
|
| 76 |
-
-------
|
| 77 |
-
dict
|
| 78 |
-
"""
|
| 79 |
-
images = []
|
| 80 |
-
for img in imgs:
|
| 81 |
-
buf = BytesIO()
|
| 82 |
-
img.save(buf, format=fmt)
|
| 83 |
-
img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 84 |
-
img_dict = {
|
| 85 |
-
"type": "image_url",
|
| 86 |
-
"image_url": {"url": f"data:image/png;base64,{img_b64}", "detail": "auto"},
|
| 87 |
-
}
|
| 88 |
-
images.append(img_dict)
|
| 89 |
-
|
| 90 |
-
# chat is in format of list[dict["role": "user"/"system", "content": str, "type": "text"],...]
|
| 91 |
-
# with images, we need "content" to be a list of dicts with "type" and "text"/"image_url"
|
| 92 |
-
# currently we do not support few-shots so only one user message
|
| 93 |
-
# text content also has <image> placeholders, which apparently is not necessary for API class (confirm)
|
| 94 |
-
|
| 95 |
-
if isinstance(chat[-1]["content"], list):
|
| 96 |
-
chat[-1]["content"] = images + chat[-1]["content"]
|
| 97 |
-
else:
|
| 98 |
-
text_content = {"type": "text", "text": chat[-1]["content"]}
|
| 99 |
-
chat[-1]["content"] = images + [text_content]
|
| 100 |
-
chat[-1].pop("type")
|
| 101 |
-
return chat
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
class TemplateAPI(TemplateLM):
|
| 105 |
-
MULTIMODAL = True
|
| 106 |
-
|
| 107 |
-
def __init__(
|
| 108 |
-
self,
|
| 109 |
-
model: str = None,
|
| 110 |
-
pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
|
| 111 |
-
base_url: str = None,
|
| 112 |
-
tokenizer: Optional[str] = None,
|
| 113 |
-
# Loglikelihood tasks require a tokenizer to calculate context lengths,
|
| 114 |
-
# however the requests can be sent as a string if the API doesn't support token inputs.
|
| 115 |
-
# use tokenized_requests=False
|
| 116 |
-
tokenizer_backend: Optional[
|
| 117 |
-
Literal["tiktoken", "huggingface", "None", "none"]
|
| 118 |
-
] = "huggingface",
|
| 119 |
-
truncate: bool = False,
|
| 120 |
-
# number of concurrent requests. More useful if not batching
|
| 121 |
-
num_concurrent: int = 1,
|
| 122 |
-
max_retries: int = 3,
|
| 123 |
-
max_gen_toks: int = 256,
|
| 124 |
-
batch_size: Union[str, int] = 1,
|
| 125 |
-
seed: int = 1234,
|
| 126 |
-
max_length: Optional[int] = 2048,
|
| 127 |
-
add_bos_token: bool = False,
|
| 128 |
-
custom_prefix_token_id: int = None,
|
| 129 |
-
# send the requests as tokens or strings
|
| 130 |
-
tokenized_requests: bool = True,
|
| 131 |
-
trust_remote_code: bool = False,
|
| 132 |
-
revision: Optional[str] = "main",
|
| 133 |
-
use_fast_tokenizer: bool = True,
|
| 134 |
-
verify_certificate: bool = True,
|
| 135 |
-
eos_string: str = None,
|
| 136 |
-
# timeout in seconds
|
| 137 |
-
timeout: int = 300,
|
| 138 |
-
max_images: int = 1,
|
| 139 |
-
**kwargs,
|
| 140 |
-
) -> None:
|
| 141 |
-
super().__init__()
|
| 142 |
-
missing_packages = [
|
| 143 |
-
pkg
|
| 144 |
-
for pkg in ["aiohttp", "tqdm", "tenacity", "requests"]
|
| 145 |
-
if find_spec(pkg) is None
|
| 146 |
-
]
|
| 147 |
-
if missing_packages:
|
| 148 |
-
raise ModuleNotFoundError(
|
| 149 |
-
f"Attempted to use an API model, but the required packages {missing_packages} are not installed. "
|
| 150 |
-
'Please install these via `pip install lm-eval[api]` or `pip install -e ."[api]"`'
|
| 151 |
-
)
|
| 152 |
-
self.model = model or pretrained
|
| 153 |
-
self.base_url = base_url
|
| 154 |
-
self.tokenizer = tokenizer
|
| 155 |
-
if not isinstance(batch_size, int) and "auto" in batch_size:
|
| 156 |
-
eval_logger.warning(
|
| 157 |
-
"Automatic batch size is not supported for API models. Defaulting to batch size 1."
|
| 158 |
-
)
|
| 159 |
-
elif int(batch_size) > 1:
|
| 160 |
-
eval_logger.warning(
|
| 161 |
-
"Batch size > 1 detected. Ensure your API supports batched requests with varying total sequence lengths."
|
| 162 |
-
)
|
| 163 |
-
self._batch_size = int(batch_size) if batch_size != "auto" else 1
|
| 164 |
-
self._truncate = truncate
|
| 165 |
-
self._max_gen_toks = int(max_gen_toks)
|
| 166 |
-
self._seed = int(seed)
|
| 167 |
-
# max_length - 1 as we always have 1 token for generation
|
| 168 |
-
eval_logger.info(f"Using max length {max_length} - 1")
|
| 169 |
-
self.max_length = max_length - 1
|
| 170 |
-
if int(num_concurrent) <= 1:
|
| 171 |
-
eval_logger.info(
|
| 172 |
-
"Concurrent requests are disabled. To enable concurrent requests, set `num_concurrent` > 1."
|
| 173 |
-
)
|
| 174 |
-
self._concurrent = int(num_concurrent)
|
| 175 |
-
self.tokenizer_backend = (
|
| 176 |
-
None if tokenizer_backend in ("None", "none") else tokenizer_backend
|
| 177 |
-
)
|
| 178 |
-
self.add_bos_token = add_bos_token
|
| 179 |
-
self.custom_prefix_token_id = custom_prefix_token_id
|
| 180 |
-
self.tokenized_requests = tokenized_requests
|
| 181 |
-
self.max_retries = int(max_retries)
|
| 182 |
-
self.verify_certificate = verify_certificate
|
| 183 |
-
self._eos_string = eos_string
|
| 184 |
-
self.timeout = int(timeout)
|
| 185 |
-
self.max_images = int(max_images)
|
| 186 |
-
|
| 187 |
-
eval_logger.info(f"Using tokenizer {self.tokenizer_backend}")
|
| 188 |
-
if self.tokenizer_backend is None:
|
| 189 |
-
self.tokenizer = None
|
| 190 |
-
self.tokenized_requests = False
|
| 191 |
-
else:
|
| 192 |
-
if self.tokenizer is None:
|
| 193 |
-
if self.tokenizer_backend == "huggingface":
|
| 194 |
-
import transformers
|
| 195 |
-
|
| 196 |
-
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 197 |
-
self.tokenizer if self.tokenizer else self.model,
|
| 198 |
-
trust_remote_code=trust_remote_code,
|
| 199 |
-
revision=revision,
|
| 200 |
-
use_fast=use_fast_tokenizer,
|
| 201 |
-
)
|
| 202 |
-
# Not used as the API will handle padding but to mirror the behavior of the HFLM
|
| 203 |
-
self.tokenizer = configure_pad_token(self.tokenizer)
|
| 204 |
-
elif self.tokenizer_backend == "tiktoken":
|
| 205 |
-
try:
|
| 206 |
-
import tiktoken
|
| 207 |
-
|
| 208 |
-
self.tokenizer = tiktoken.encoding_for_model(self.model)
|
| 209 |
-
except ModuleNotFoundError as e:
|
| 210 |
-
raise ModuleNotFoundError(
|
| 211 |
-
"Attempted to use 'openai' LM type, but the package `tiktoken` is not installed. "
|
| 212 |
-
"Please install it via `pip install lm-eval[api]` or `pip install -e .[api]`."
|
| 213 |
-
) from e
|
| 214 |
-
if "openai" not in self.base_url:
|
| 215 |
-
eval_logger.warning(
|
| 216 |
-
f"Passed `base_url={self.base_url}` but using (OpenAI) Tiktoken tokenizer backend. "
|
| 217 |
-
"Pass `tokenizer_backend=huggingface` and provide the HF tokenizer name if your model does not use Tiktoken."
|
| 218 |
-
)
|
| 219 |
-
else:
|
| 220 |
-
import transformers
|
| 221 |
-
|
| 222 |
-
assert isinstance(tokenizer, str), "tokenizer must be a string"
|
| 223 |
-
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 224 |
-
tokenizer,
|
| 225 |
-
trust_remote_code=trust_remote_code,
|
| 226 |
-
revision=revision,
|
| 227 |
-
use_fast=use_fast_tokenizer,
|
| 228 |
-
)
|
| 229 |
-
|
| 230 |
-
@abc.abstractmethod
|
| 231 |
-
def _create_payload(
|
| 232 |
-
self,
|
| 233 |
-
messages: Union[List[List[int]], List[dict], List[str], str],
|
| 234 |
-
*,
|
| 235 |
-
generate: bool = True,
|
| 236 |
-
gen_kwargs: Optional[dict] = None,
|
| 237 |
-
seed: int = 1234,
|
| 238 |
-
eos: str = None,
|
| 239 |
-
**kwargs,
|
| 240 |
-
) -> dict:
|
| 241 |
-
"""This method is responsible for creating the json payload that will be sent to the API."""
|
| 242 |
-
raise NotImplementedError
|
| 243 |
-
|
| 244 |
-
def create_message(
|
| 245 |
-
self,
|
| 246 |
-
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
|
| 247 |
-
generate=False,
|
| 248 |
-
) -> Union[List[List[int]], List[dict], List[str], str]:
|
| 249 |
-
"""Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
|
| 250 |
-
if isinstance(messages[0], JsonChatStr):
|
| 251 |
-
# for chat completions we need to decode the json string to list[dict,...]
|
| 252 |
-
assert self._batch_size == 1, (
|
| 253 |
-
"non-tokenized chat requests are only supported with batch_size=1"
|
| 254 |
-
)
|
| 255 |
-
# list[dict["role":..., "content":...],...]
|
| 256 |
-
return json.loads(messages[0].prompt)
|
| 257 |
-
|
| 258 |
-
if not self.tokenized_requests:
|
| 259 |
-
# if messages are tokenized:
|
| 260 |
-
if isinstance(messages[0][0], int):
|
| 261 |
-
# assuming decoding is lossless. However, this is only for loglikelihood requests
|
| 262 |
-
# as we need to compute the context length. For generations, we don't need to tokenize.
|
| 263 |
-
messages = self.decode_batch(messages)
|
| 264 |
-
if self._batch_size <= 1:
|
| 265 |
-
# if batch is 1 return str
|
| 266 |
-
return messages[0]
|
| 267 |
-
else:
|
| 268 |
-
# list[str,...]
|
| 269 |
-
return messages
|
| 270 |
-
|
| 271 |
-
# list[list[int], ...]
|
| 272 |
-
return messages
|
| 273 |
-
|
| 274 |
-
@staticmethod
|
| 275 |
-
@abc.abstractmethod
|
| 276 |
-
def parse_logprobs(
|
| 277 |
-
outputs: Union[Any, List[Any]],
|
| 278 |
-
tokens: List[List[int]] = None,
|
| 279 |
-
ctxlen: List[int] = None,
|
| 280 |
-
**kwargs,
|
| 281 |
-
) -> List[Tuple[float, bool]]:
|
| 282 |
-
"""Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
|
| 283 |
-
raise NotImplementedError
|
| 284 |
-
|
| 285 |
-
@staticmethod
|
| 286 |
-
@abc.abstractmethod
|
| 287 |
-
def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
|
| 288 |
-
"""Method used to parse the generations from the (batched) API response. This method should return a list of str"""
|
| 289 |
-
raise NotImplementedError
|
| 290 |
-
|
| 291 |
-
@cached_property
|
| 292 |
-
def api_key(self) -> str:
|
| 293 |
-
"""Override this property to return the API key for the API request."""
|
| 294 |
-
return ""
|
| 295 |
-
|
| 296 |
-
@cached_property
|
| 297 |
-
def header(self) -> dict:
|
| 298 |
-
"""Override this property to return the headers for the API request."""
|
| 299 |
-
return {"Authorization": f"Bearer {self.api_key}"}
|
| 300 |
-
|
| 301 |
-
@property
|
| 302 |
-
def tokenizer_name(self) -> str:
|
| 303 |
-
"""Must be defined for LM subclasses which implement Chat Templating.
|
| 304 |
-
Should return the name of the tokenizer or chat template used.
|
| 305 |
-
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
|
| 306 |
-
"""
|
| 307 |
-
return ""
|
| 308 |
-
|
| 309 |
-
def apply_chat_template(
|
| 310 |
-
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
|
| 311 |
-
) -> Union[str, JsonChatStr]:
|
| 312 |
-
"""Applies a chat template to a list of chat history between user and model."""
|
| 313 |
-
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
|
| 314 |
-
return self.tokenizer.apply_chat_template(
|
| 315 |
-
chat_history,
|
| 316 |
-
tokenize=False,
|
| 317 |
-
add_generation_prompt=add_generation_prompt,
|
| 318 |
-
continue_final_message=not add_generation_prompt,
|
| 319 |
-
)
|
| 320 |
-
else:
|
| 321 |
-
# bit of a hack. We'll load back before sending to the API
|
| 322 |
-
return JsonChatStr(
|
| 323 |
-
json.dumps(
|
| 324 |
-
[{**item, "type": "text"} for item in chat_history],
|
| 325 |
-
ensure_ascii=False,
|
| 326 |
-
)
|
| 327 |
-
)
|
| 328 |
-
|
| 329 |
-
@cached_property
|
| 330 |
-
def eot_token_id(self) -> Optional[int]:
|
| 331 |
-
if self.tokenizer is None:
|
| 332 |
-
return None
|
| 333 |
-
else:
|
| 334 |
-
if self.tokenizer_backend == "huggingface":
|
| 335 |
-
return self.tokenizer.eos_token_id
|
| 336 |
-
elif self.tokenizer_backend == "tiktoken":
|
| 337 |
-
return self.tokenizer.eot_token
|
| 338 |
-
|
| 339 |
-
@cached_property
|
| 340 |
-
def eos_string(self) -> Optional[str]:
|
| 341 |
-
if self._eos_string:
|
| 342 |
-
return self._eos_string
|
| 343 |
-
elif self.tokenizer is not None:
|
| 344 |
-
if self.tokenizer_backend == "huggingface":
|
| 345 |
-
return self.tokenizer.eos_token
|
| 346 |
-
elif self.tokenizer_backend == "tiktoken":
|
| 347 |
-
return self.tokenizer.decode([self.tokenizer.eot_token])
|
| 348 |
-
else:
|
| 349 |
-
eval_logger.warning(
|
| 350 |
-
"Cannot determine EOS string to pass to stop sequence. Manually set by passing `eos_string` to model_args."
|
| 351 |
-
)
|
| 352 |
-
return None
|
| 353 |
-
|
| 354 |
-
@cached_property
|
| 355 |
-
def prefix_token_id(self) -> Optional[int]:
|
| 356 |
-
if self.tokenizer is None:
|
| 357 |
-
return None
|
| 358 |
-
else:
|
| 359 |
-
if self.custom_prefix_token_id is not None:
|
| 360 |
-
return self.custom_prefix_token_id
|
| 361 |
-
if self.tokenizer_backend == "huggingface":
|
| 362 |
-
if self.tokenizer.bos_token_id is not None:
|
| 363 |
-
return self.tokenizer.bos_token_id
|
| 364 |
-
return self.tokenizer.eos_token_id
|
| 365 |
-
else:
|
| 366 |
-
return self.tokenizer.eot_token
|
| 367 |
-
|
| 368 |
-
def tok_encode(
|
| 369 |
-
self,
|
| 370 |
-
string: str,
|
| 371 |
-
left_truncate_len: int = None,
|
| 372 |
-
add_special_tokens: bool = False,
|
| 373 |
-
truncation: bool = False,
|
| 374 |
-
**kwargs,
|
| 375 |
-
) -> Union[List[List[int]], List[int], List[str]]:
|
| 376 |
-
if self.tokenizer_backend is None:
|
| 377 |
-
return [string]
|
| 378 |
-
elif self.tokenizer_backend == "huggingface":
|
| 379 |
-
# by default for CausalLM - false or self.add_bos_token is set
|
| 380 |
-
if not add_special_tokens:
|
| 381 |
-
add_special_tokens = False or self.add_bos_token
|
| 382 |
-
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
|
| 383 |
-
string,
|
| 384 |
-
add_special_tokens=add_special_tokens,
|
| 385 |
-
truncation=truncation,
|
| 386 |
-
return_attention_mask=False,
|
| 387 |
-
).input_ids
|
| 388 |
-
|
| 389 |
-
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
|
| 390 |
-
if left_truncate_len:
|
| 391 |
-
if not isinstance(string, str):
|
| 392 |
-
encoding = [enc[-left_truncate_len:] for enc in encoding]
|
| 393 |
-
else:
|
| 394 |
-
encoding = encoding[-left_truncate_len:]
|
| 395 |
-
|
| 396 |
-
return encoding
|
| 397 |
-
|
| 398 |
-
else:
|
| 399 |
-
try:
|
| 400 |
-
encoding = self.tokenizer.encode(string)
|
| 401 |
-
except Exception:
|
| 402 |
-
encoding = self.tokenizer.encode_batch(string)
|
| 403 |
-
return encoding
|
| 404 |
-
|
| 405 |
-
def decode_batch(self, tokens: List[List[int]]) -> List[str]:
|
| 406 |
-
if self.tokenizer_backend == "huggingface":
|
| 407 |
-
return self.tokenizer.batch_decode(tokens)
|
| 408 |
-
elif self.tokenizer_backend == "tiktoken":
|
| 409 |
-
return self.tokenizer.decode_batch(tokens)
|
| 410 |
-
|
| 411 |
-
def model_call(
|
| 412 |
-
self,
|
| 413 |
-
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
|
| 414 |
-
*,
|
| 415 |
-
generate: bool = True,
|
| 416 |
-
gen_kwargs: Optional[Dict] = None,
|
| 417 |
-
**kwargs,
|
| 418 |
-
) -> Optional[dict]:
|
| 419 |
-
# !!! Copy: shared dict for each request, need new object !!!
|
| 420 |
-
gen_kwargs = copy.deepcopy(gen_kwargs)
|
| 421 |
-
try:
|
| 422 |
-
response = requests.post(
|
| 423 |
-
self.base_url,
|
| 424 |
-
json=self._create_payload(
|
| 425 |
-
self.create_message(messages),
|
| 426 |
-
generate=generate,
|
| 427 |
-
gen_kwargs=gen_kwargs,
|
| 428 |
-
seed=self._seed,
|
| 429 |
-
eos=self.eos_string,
|
| 430 |
-
**kwargs,
|
| 431 |
-
),
|
| 432 |
-
headers=self.header,
|
| 433 |
-
verify=self.verify_certificate,
|
| 434 |
-
)
|
| 435 |
-
if not response.ok:
|
| 436 |
-
eval_logger.warning(
|
| 437 |
-
f"API request failed with error message: {response.text}. Retrying..."
|
| 438 |
-
)
|
| 439 |
-
response.raise_for_status()
|
| 440 |
-
return response.json()
|
| 441 |
-
except RetryError:
|
| 442 |
-
eval_logger.error(
|
| 443 |
-
"API request failed after multiple retries. Please check the API status."
|
| 444 |
-
)
|
| 445 |
-
return None
|
| 446 |
-
|
| 447 |
-
async def amodel_call(
|
| 448 |
-
self,
|
| 449 |
-
session: ClientSession,
|
| 450 |
-
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
|
| 451 |
-
*,
|
| 452 |
-
generate: bool = True,
|
| 453 |
-
cache_keys: list = None,
|
| 454 |
-
ctxlens: Optional[List[int]] = None,
|
| 455 |
-
gen_kwargs: Optional[Dict] = None,
|
| 456 |
-
**kwargs,
|
| 457 |
-
) -> Union[List[str], List[Tuple[float, bool]], None]:
|
| 458 |
-
# !!! Copy: shared dict for each request, need new object !!!
|
| 459 |
-
gen_kwargs = copy.deepcopy(gen_kwargs)
|
| 460 |
-
payload = self._create_payload(
|
| 461 |
-
self.create_message(messages),
|
| 462 |
-
generate=generate,
|
| 463 |
-
gen_kwargs=gen_kwargs,
|
| 464 |
-
seed=self._seed,
|
| 465 |
-
**kwargs,
|
| 466 |
-
)
|
| 467 |
-
cache_method = "generate_until" if generate else "loglikelihood"
|
| 468 |
-
try:
|
| 469 |
-
async with session.post(
|
| 470 |
-
self.base_url,
|
| 471 |
-
json=payload,
|
| 472 |
-
headers=self.header,
|
| 473 |
-
) as response:
|
| 474 |
-
if not response.ok:
|
| 475 |
-
error_text = await response.text()
|
| 476 |
-
eval_logger.warning(
|
| 477 |
-
f"API request failed with error message: {error_text}. Retrying..."
|
| 478 |
-
)
|
| 479 |
-
# raising exception will retry the request
|
| 480 |
-
response.raise_for_status()
|
| 481 |
-
outputs = await response.json()
|
| 482 |
-
answers = (
|
| 483 |
-
self.parse_generations(
|
| 484 |
-
outputs=outputs,
|
| 485 |
-
)
|
| 486 |
-
if generate
|
| 487 |
-
else self.parse_logprobs(
|
| 488 |
-
outputs=outputs,
|
| 489 |
-
tokens=messages,
|
| 490 |
-
ctxlens=ctxlens,
|
| 491 |
-
)
|
| 492 |
-
)
|
| 493 |
-
if cache_keys:
|
| 494 |
-
for res, cache in zip(answers, cache_keys):
|
| 495 |
-
self.cache_hook.add_partial(cache_method, cache, res)
|
| 496 |
-
return answers
|
| 497 |
-
# If the retries also fail
|
| 498 |
-
except RetryError:
|
| 499 |
-
eval_logger.error(
|
| 500 |
-
"API request failed after multiple retries. Please check the API status."
|
| 501 |
-
)
|
| 502 |
-
return None
|
| 503 |
-
|
| 504 |
-
def batch_loglikelihood_requests(
|
| 505 |
-
self, chunks: Iterable[List[LogLikelihoodInputs]]
|
| 506 |
-
) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
|
| 507 |
-
inputs = []
|
| 508 |
-
ctxlens = []
|
| 509 |
-
cache_keys = []
|
| 510 |
-
for chunk in chunks:
|
| 511 |
-
for cache_key, context_enc, continuation_enc in chunk:
|
| 512 |
-
# max_length - 1 as we always have 1 token for generation
|
| 513 |
-
inp = (context_enc + continuation_enc)[-self.max_length :]
|
| 514 |
-
if len(inp) < len(context_enc + continuation_enc):
|
| 515 |
-
eval_logger.warning(
|
| 516 |
-
f"Context length ({len(context_enc)}) + continuation length ({len(continuation_enc)}) > max_length ({self.max_length}). Left truncating context."
|
| 517 |
-
)
|
| 518 |
-
ctxlen = len(context_enc) - max(
|
| 519 |
-
0, len(context_enc) + len(continuation_enc) - self.max_length
|
| 520 |
-
)
|
| 521 |
-
|
| 522 |
-
inputs.append(inp)
|
| 523 |
-
ctxlens.append(ctxlen)
|
| 524 |
-
cache_keys.append(cache_key)
|
| 525 |
-
return inputs, ctxlens, cache_keys
|
| 526 |
-
|
| 527 |
-
async def get_batched_requests(
|
| 528 |
-
self,
|
| 529 |
-
requests: list,
|
| 530 |
-
cache_keys: list,
|
| 531 |
-
*,
|
| 532 |
-
generate: bool = True,
|
| 533 |
-
ctxlens: List[int] = None,
|
| 534 |
-
**kwargs,
|
| 535 |
-
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
|
| 536 |
-
ctxlens = ctxlens if ctxlens else [None] * len(requests)
|
| 537 |
-
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
|
| 538 |
-
async with ClientSession(
|
| 539 |
-
connector=conn, timeout=ClientTimeout(total=self.timeout)
|
| 540 |
-
) as session:
|
| 541 |
-
retry_: Callable[..., Awaitable[Any]] = retry(
|
| 542 |
-
stop=stop_after_attempt(self.max_retries),
|
| 543 |
-
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
| 544 |
-
reraise=True,
|
| 545 |
-
)(self.amodel_call)
|
| 546 |
-
# Create tasks for each batch of request
|
| 547 |
-
tasks = [
|
| 548 |
-
asyncio.create_task(
|
| 549 |
-
retry_(
|
| 550 |
-
session=session,
|
| 551 |
-
messages=message,
|
| 552 |
-
cache_keys=cache_key,
|
| 553 |
-
generate=generate,
|
| 554 |
-
ctxlens=ctxlen,
|
| 555 |
-
**kwargs,
|
| 556 |
-
)
|
| 557 |
-
)
|
| 558 |
-
for message, cache_key, ctxlen in zip(
|
| 559 |
-
chunks(requests, n=self._batch_size),
|
| 560 |
-
chunks(cache_keys, n=self._batch_size),
|
| 561 |
-
chunks(ctxlens, n=self._batch_size),
|
| 562 |
-
)
|
| 563 |
-
]
|
| 564 |
-
|
| 565 |
-
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
|
| 566 |
-
|
| 567 |
-
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
|
| 568 |
-
assert self.tokenizer is not None, (
|
| 569 |
-
"Tokenizer is required for loglikelihood tasks to compute context lengths."
|
| 570 |
-
)
|
| 571 |
-
res = []
|
| 572 |
-
|
| 573 |
-
def _collate(req: LogLikelihoodInputs):
|
| 574 |
-
"""Defines the key for the sorted method"""
|
| 575 |
-
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 576 |
-
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 577 |
-
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 578 |
-
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 579 |
-
# automatic adaptive batches much much easier to implement
|
| 580 |
-
# - any OOMs will happen right away rather than near the end
|
| 581 |
-
|
| 582 |
-
toks = req[1] + req[2]
|
| 583 |
-
return -len(toks), tuple(toks)
|
| 584 |
-
|
| 585 |
-
re_ord = Collator(
|
| 586 |
-
requests,
|
| 587 |
-
sort_fn=_collate,
|
| 588 |
-
group_by=None,
|
| 589 |
-
)
|
| 590 |
-
# if concurrent then we'll batch in the async context
|
| 591 |
-
chunked = re_ord.get_batched(n=self._batch_size if self._concurrent <= 1 else 0)
|
| 592 |
-
if self._concurrent <= 1:
|
| 593 |
-
pbar = tqdm(desc="Requesting API", total=len(requests))
|
| 594 |
-
for chunk in chunked:
|
| 595 |
-
inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests([chunk])
|
| 596 |
-
|
| 597 |
-
outputs = retry(
|
| 598 |
-
stop=stop_after_attempt(self.max_retries),
|
| 599 |
-
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
| 600 |
-
reraise=True,
|
| 601 |
-
)(self.model_call)(messages=inputs, generate=False)
|
| 602 |
-
if isinstance(outputs, dict):
|
| 603 |
-
outputs = [outputs]
|
| 604 |
-
for answer_, cache_key in zip(
|
| 605 |
-
self.parse_logprobs(
|
| 606 |
-
outputs=outputs, tokens=inputs, ctxlens=ctxlens
|
| 607 |
-
),
|
| 608 |
-
cache_keys,
|
| 609 |
-
):
|
| 610 |
-
if answer_ is not None:
|
| 611 |
-
res.append(answer_)
|
| 612 |
-
# cache requests that aren't from a loglikelihood_rolling request
|
| 613 |
-
if cache_key is not None:
|
| 614 |
-
self.cache_hook.add_partial(
|
| 615 |
-
"loglikelihood", cache_key, answer_
|
| 616 |
-
)
|
| 617 |
-
pbar.update(1)
|
| 618 |
-
else:
|
| 619 |
-
inputs, ctxlens, cache_keys = self.batch_loglikelihood_requests(chunked)
|
| 620 |
-
res = itertools.chain.from_iterable(
|
| 621 |
-
asyncio.run(
|
| 622 |
-
self.get_batched_requests(
|
| 623 |
-
inputs, cache_keys, generate=False, ctxlens=ctxlens
|
| 624 |
-
)
|
| 625 |
-
)
|
| 626 |
-
)
|
| 627 |
-
|
| 628 |
-
return re_ord.get_original(res)
|
| 629 |
-
|
| 630 |
-
def generate_until(
|
| 631 |
-
self, requests: List[Instance], disable_tqdm: bool = False
|
| 632 |
-
) -> List[str]:
|
| 633 |
-
res = []
|
| 634 |
-
|
| 635 |
-
def _collate_gen(_requests):
|
| 636 |
-
# sort by the length of the non-tokenized contexts
|
| 637 |
-
return -len(_requests[0])
|
| 638 |
-
|
| 639 |
-
# Let the API deal with tokenization
|
| 640 |
-
if len(requests[0].args) > 2:
|
| 641 |
-
assert self.tokenizer is None, (
|
| 642 |
-
"tokenizer is not supported for multimodal requests yet!"
|
| 643 |
-
)
|
| 644 |
-
eval_logger.info(
|
| 645 |
-
f"Using max_images {self.max_images}. Set in the model args."
|
| 646 |
-
)
|
| 647 |
-
requests, all_gen_kwargs, auxiliary_args = zip(
|
| 648 |
-
*(req.args for req in requests)
|
| 649 |
-
)
|
| 650 |
-
requests = tuple(
|
| 651 |
-
JsonChatStr(
|
| 652 |
-
json.dumps(
|
| 653 |
-
create_image_prompt(
|
| 654 |
-
y["visual"][: self.max_images], json.loads(x.prompt)
|
| 655 |
-
)
|
| 656 |
-
)
|
| 657 |
-
)
|
| 658 |
-
for x, y in zip(requests, auxiliary_args)
|
| 659 |
-
)
|
| 660 |
-
else:
|
| 661 |
-
requests, all_gen_kwargs = zip(*(req.args for req in requests))
|
| 662 |
-
if self.tokenized_requests:
|
| 663 |
-
encodings_list = self.tok_encode(
|
| 664 |
-
requests, add_special_tokens=self.add_bos_token
|
| 665 |
-
)
|
| 666 |
-
else:
|
| 667 |
-
encodings_list = [None] * len(requests)
|
| 668 |
-
requests = [
|
| 669 |
-
(a, b, c) for a, b, c in zip(requests, all_gen_kwargs, encodings_list)
|
| 670 |
-
]
|
| 671 |
-
|
| 672 |
-
re_ord = Collator(
|
| 673 |
-
requests,
|
| 674 |
-
sort_fn=_collate_gen,
|
| 675 |
-
group_by="gen_kwargs",
|
| 676 |
-
)
|
| 677 |
-
chunked = re_ord.get_batched(
|
| 678 |
-
n=self._batch_size if self._concurrent <= 1 else 0, batch_fn=None
|
| 679 |
-
)
|
| 680 |
-
if not self.tokenized_requests:
|
| 681 |
-
eval_logger.info(
|
| 682 |
-
"Tokenized requests are disabled. Context + generation length is not checked."
|
| 683 |
-
)
|
| 684 |
-
if self._concurrent <= 1:
|
| 685 |
-
pbar = tqdm(desc="Requesting API", total=len(requests))
|
| 686 |
-
for chunk in chunked:
|
| 687 |
-
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
|
| 688 |
-
if self.tokenized_requests:
|
| 689 |
-
max_gen_toks = all_gen_kwargs[0].get(
|
| 690 |
-
"max_gen_toks", self._max_gen_toks
|
| 691 |
-
)
|
| 692 |
-
max_context_len = self.max_length - max_gen_toks
|
| 693 |
-
|
| 694 |
-
encodings_list = [x[-max_context_len:] for x in encodings_list]
|
| 695 |
-
|
| 696 |
-
if any(
|
| 697 |
-
len(x) + max_gen_toks > self.max_length for x in encodings_list
|
| 698 |
-
):
|
| 699 |
-
eval_logger.warning(
|
| 700 |
-
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks: ({max_gen_toks}). They were left truncated."
|
| 701 |
-
)
|
| 702 |
-
|
| 703 |
-
req = encodings_list if self.tokenized_requests else contexts
|
| 704 |
-
outputs = retry(
|
| 705 |
-
stop=stop_after_attempt(self.max_retries),
|
| 706 |
-
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
| 707 |
-
reraise=True,
|
| 708 |
-
)(self.model_call)(
|
| 709 |
-
messages=req,
|
| 710 |
-
generate=True,
|
| 711 |
-
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
|
| 712 |
-
)
|
| 713 |
-
for generated_text, context in zip(
|
| 714 |
-
self.parse_generations(
|
| 715 |
-
outputs=outputs,
|
| 716 |
-
contexts=contexts,
|
| 717 |
-
),
|
| 718 |
-
contexts,
|
| 719 |
-
):
|
| 720 |
-
if generated_text is not None:
|
| 721 |
-
res.append(generated_text)
|
| 722 |
-
|
| 723 |
-
# partial caching
|
| 724 |
-
if context is not None:
|
| 725 |
-
self.cache_hook.add_partial(
|
| 726 |
-
"generate_until",
|
| 727 |
-
(context, all_gen_kwargs[0]),
|
| 728 |
-
generated_text,
|
| 729 |
-
)
|
| 730 |
-
pbar.update(1)
|
| 731 |
-
else:
|
| 732 |
-
for chunk in chunked:
|
| 733 |
-
contexts, all_gen_kwargs, encodings_list = zip(*chunk)
|
| 734 |
-
if self.tokenized_requests:
|
| 735 |
-
max_gen_toks = all_gen_kwargs[0].get(
|
| 736 |
-
"max_gen_toks", self._max_gen_toks
|
| 737 |
-
)
|
| 738 |
-
max_context_len = self.max_length - max_gen_toks
|
| 739 |
-
|
| 740 |
-
encodings_list = [x[-max_context_len:] for x in encodings_list]
|
| 741 |
-
|
| 742 |
-
if any(
|
| 743 |
-
len(x) + max_gen_toks > self.max_length for x in encodings_list
|
| 744 |
-
):
|
| 745 |
-
eval_logger.warning(
|
| 746 |
-
f"Some contexts exceeded (max length: ({self.max_length}) - max_gen_toks ({max_gen_toks}). They were left truncated."
|
| 747 |
-
)
|
| 748 |
-
|
| 749 |
-
req = encodings_list if self.tokenized_requests else contexts
|
| 750 |
-
results = itertools.chain.from_iterable(
|
| 751 |
-
asyncio.run(
|
| 752 |
-
self.get_batched_requests(
|
| 753 |
-
req,
|
| 754 |
-
cache_keys=[(ctx, all_gen_kwargs[0]) for ctx in contexts],
|
| 755 |
-
generate=True,
|
| 756 |
-
gen_kwargs=copy.deepcopy(all_gen_kwargs[0]),
|
| 757 |
-
)
|
| 758 |
-
)
|
| 759 |
-
)
|
| 760 |
-
res.extend(results)
|
| 761 |
-
|
| 762 |
-
return re_ord.get_original(res)
|
| 763 |
-
|
| 764 |
-
def loglikelihood_rolling(
|
| 765 |
-
self, requests: List[Instance], disable_tqdm: bool = False
|
| 766 |
-
) -> List[float]:
|
| 767 |
-
loglikelihoods = []
|
| 768 |
-
|
| 769 |
-
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
|
| 770 |
-
rolling_token_windows = list(
|
| 771 |
-
map(
|
| 772 |
-
utils.make_disjoint_window,
|
| 773 |
-
utils.get_rolling_token_windows(
|
| 774 |
-
token_list=self.tok_encode(string),
|
| 775 |
-
prefix_token=self.prefix_token_id,
|
| 776 |
-
# max_seq_len - (1 for context)
|
| 777 |
-
max_seq_len=self.max_length - 1,
|
| 778 |
-
context_len=1,
|
| 779 |
-
),
|
| 780 |
-
)
|
| 781 |
-
)
|
| 782 |
-
|
| 783 |
-
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
|
| 784 |
-
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
|
| 785 |
-
|
| 786 |
-
string_nll = self._loglikelihood_tokens(
|
| 787 |
-
rolling_token_windows,
|
| 788 |
-
disable_tqdm=True,
|
| 789 |
-
)
|
| 790 |
-
|
| 791 |
-
# discard is_greedy
|
| 792 |
-
string_nll = [x[0] for x in string_nll]
|
| 793 |
-
|
| 794 |
-
string_nll = sum(string_nll)
|
| 795 |
-
loglikelihoods.append(string_nll)
|
| 796 |
-
|
| 797 |
-
# cache this loglikelihood_rolling request
|
| 798 |
-
self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
|
| 799 |
-
return loglikelihoods
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/dummy.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
import random
|
| 2 |
-
|
| 3 |
-
from tqdm import tqdm
|
| 4 |
-
|
| 5 |
-
from lm_eval.api.model import LM
|
| 6 |
-
from lm_eval.api.registry import register_model
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
@register_model("dummy")
|
| 10 |
-
class DummyLM(LM):
|
| 11 |
-
def __init__(self) -> None:
|
| 12 |
-
super().__init__()
|
| 13 |
-
|
| 14 |
-
@classmethod
|
| 15 |
-
def create_from_arg_string(cls, arg_string, additional_config=None):
|
| 16 |
-
return cls()
|
| 17 |
-
|
| 18 |
-
def loglikelihood(self, requests, disable_tqdm: bool = False):
|
| 19 |
-
res = []
|
| 20 |
-
|
| 21 |
-
for _ in tqdm(requests, disable=disable_tqdm):
|
| 22 |
-
res.append((-random.random(), False))
|
| 23 |
-
|
| 24 |
-
return res
|
| 25 |
-
|
| 26 |
-
def generate_until(self, requests, disable_tqdm: bool = False):
|
| 27 |
-
res = []
|
| 28 |
-
|
| 29 |
-
for request in tqdm(requests, disable=disable_tqdm):
|
| 30 |
-
res.append("lol")
|
| 31 |
-
assert request.arguments[0].strip() != ""
|
| 32 |
-
|
| 33 |
-
return res
|
| 34 |
-
|
| 35 |
-
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
|
| 36 |
-
res = []
|
| 37 |
-
|
| 38 |
-
for _ in tqdm(requests, disable=disable_tqdm):
|
| 39 |
-
res.append(-random.random())
|
| 40 |
-
|
| 41 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/gguf.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
import time
|
| 3 |
-
|
| 4 |
-
import requests
|
| 5 |
-
from requests.exceptions import RequestException
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
|
| 8 |
-
from lm_eval.api.model import LM
|
| 9 |
-
from lm_eval.api.registry import register_model
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
logger = logging.getLogger(__name__)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def get_result(logprobs, context_length):
|
| 16 |
-
is_greedy = True
|
| 17 |
-
offsets = logprobs["text_offset"]
|
| 18 |
-
tokens = logprobs["tokens"]
|
| 19 |
-
tokens_logprobs = logprobs["token_logprobs"]
|
| 20 |
-
|
| 21 |
-
idx = 0
|
| 22 |
-
while offsets[idx] < context_length:
|
| 23 |
-
idx += 1
|
| 24 |
-
continuation_logprobs = sum(tokens_logprobs[idx:-1])
|
| 25 |
-
for i in range(idx, len(tokens)):
|
| 26 |
-
token = tokens[i]
|
| 27 |
-
top_tokens = logprobs["top_logprobs"][i]
|
| 28 |
-
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
|
| 29 |
-
if top_token != token:
|
| 30 |
-
is_greedy = False
|
| 31 |
-
break
|
| 32 |
-
|
| 33 |
-
return continuation_logprobs, is_greedy
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
@register_model("gguf", "ggml")
|
| 37 |
-
class GGUFLM(LM):
|
| 38 |
-
def __init__(self, base_url=None, max_length=2048, **kwargs):
|
| 39 |
-
super().__init__()
|
| 40 |
-
self.base_url = base_url
|
| 41 |
-
assert self.base_url, "must pass `base_url` to use GGUF LM!"
|
| 42 |
-
self.logprobs = 10
|
| 43 |
-
self.temperature = 0.0
|
| 44 |
-
self.max_length = max_length
|
| 45 |
-
|
| 46 |
-
def gguf_completion(
|
| 47 |
-
self, context, continuation=None, stop=None, retries=3, delay=5, **kwargs
|
| 48 |
-
):
|
| 49 |
-
for _ in range(retries):
|
| 50 |
-
try:
|
| 51 |
-
prompt = context
|
| 52 |
-
request = {
|
| 53 |
-
"prompt": prompt,
|
| 54 |
-
"logprobs": self.logprobs,
|
| 55 |
-
"temperature": self.temperature,
|
| 56 |
-
}
|
| 57 |
-
if continuation:
|
| 58 |
-
prompt += continuation
|
| 59 |
-
request.update({"prompt": prompt, "max_tokens": 1, "echo": True})
|
| 60 |
-
if stop is not None:
|
| 61 |
-
request["stop"] = stop
|
| 62 |
-
response = requests.post(
|
| 63 |
-
f"{self.base_url}/v1/completions", json=request
|
| 64 |
-
)
|
| 65 |
-
response.raise_for_status()
|
| 66 |
-
return response.json()
|
| 67 |
-
except RequestException as e:
|
| 68 |
-
logger.error(f"RequestException: {e}")
|
| 69 |
-
time.sleep(delay) # wait before retrying
|
| 70 |
-
else:
|
| 71 |
-
raise RuntimeError(
|
| 72 |
-
f"Failed to get a valid response after {retries} retries."
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
def loglikelihood(self, requests, disable_tqdm: bool = False):
|
| 76 |
-
if not requests:
|
| 77 |
-
return []
|
| 78 |
-
res = []
|
| 79 |
-
for context, continuation in tqdm(
|
| 80 |
-
[req.args for req in requests], disable=disable_tqdm
|
| 81 |
-
):
|
| 82 |
-
response = self.gguf_completion(context=context, continuation=continuation)
|
| 83 |
-
if response and "choices" in response and response["choices"]:
|
| 84 |
-
choice = response["choices"][0]
|
| 85 |
-
logprobs = choice.get("logprobs")
|
| 86 |
-
if (
|
| 87 |
-
logprobs
|
| 88 |
-
and "token_logprobs" in logprobs
|
| 89 |
-
and logprobs["token_logprobs"]
|
| 90 |
-
):
|
| 91 |
-
logprob, is_greedy = get_result(logprobs, len(context))
|
| 92 |
-
res.append((logprob, is_greedy))
|
| 93 |
-
else:
|
| 94 |
-
logger.warning(
|
| 95 |
-
"Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list."
|
| 96 |
-
)
|
| 97 |
-
else:
|
| 98 |
-
logger.error(
|
| 99 |
-
f"Invalid response for loglikelihood. Response: {response}"
|
| 100 |
-
)
|
| 101 |
-
assert False
|
| 102 |
-
return res
|
| 103 |
-
|
| 104 |
-
def generate_until(self, requests, disable_tqdm: bool = False):
|
| 105 |
-
if not requests:
|
| 106 |
-
return []
|
| 107 |
-
|
| 108 |
-
res = []
|
| 109 |
-
for request in tqdm([req.args for req in requests], disable=disable_tqdm):
|
| 110 |
-
inp = request[0]
|
| 111 |
-
request_args = request[1]
|
| 112 |
-
until = request_args.get("until", ["</s>"])
|
| 113 |
-
response = self.gguf_completion(context=inp, stop=until)
|
| 114 |
-
if response and "choices" in response and response["choices"]:
|
| 115 |
-
choice = response["choices"][0]
|
| 116 |
-
if "text" in choice:
|
| 117 |
-
generated_text = choice["text"].strip()
|
| 118 |
-
res.append(generated_text)
|
| 119 |
-
else:
|
| 120 |
-
logger.error(
|
| 121 |
-
f"Invalid response for greedy_until. Response: {response}"
|
| 122 |
-
)
|
| 123 |
-
res.append(None) # Add default value in case of error
|
| 124 |
-
else:
|
| 125 |
-
logger.error(f"Invalid response for greedy_until. Response: {response}")
|
| 126 |
-
res.append(None) # Add default value in case of error
|
| 127 |
-
return res
|
| 128 |
-
|
| 129 |
-
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
|
| 130 |
-
raise NotImplementedError(
|
| 131 |
-
"loglikelihood_rolling not yet supported for GGUF models"
|
| 132 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/hf_audiolm.py
DELETED
|
@@ -1,307 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 3 |
-
|
| 4 |
-
import torch
|
| 5 |
-
import transformers
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
from transformers import BatchEncoding
|
| 8 |
-
|
| 9 |
-
from lm_eval.api.instance import Instance
|
| 10 |
-
from lm_eval.api.registry import register_model
|
| 11 |
-
from lm_eval.models.huggingface import HFLM
|
| 12 |
-
from lm_eval.models.utils import (
|
| 13 |
-
Collator,
|
| 14 |
-
replace_placeholders,
|
| 15 |
-
stop_sequences_criteria,
|
| 16 |
-
)
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
DEFAULT_AUDIO_PLACEHOLDERS = ["<audio>"]
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
@register_model("hf-audiolm-qwen")
|
| 23 |
-
class HFAUDIOLMQWEN(HFLM):
|
| 24 |
-
"""
|
| 25 |
-
An abstracted Hugging Face model class for Audio LM model like Qwen2-Audio.
|
| 26 |
-
"""
|
| 27 |
-
|
| 28 |
-
AUTO_MODEL_CLASS = transformers.Qwen2AudioForConditionalGeneration
|
| 29 |
-
MULTIMODAL = True # flag to indicate, for now, that this model type can run multimodal requests
|
| 30 |
-
|
| 31 |
-
def __init__(
|
| 32 |
-
self,
|
| 33 |
-
pretrained: Union[str, transformers.PreTrainedModel],
|
| 34 |
-
max_audios: Optional[int] = 5,
|
| 35 |
-
**kwargs,
|
| 36 |
-
):
|
| 37 |
-
# We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
|
| 38 |
-
# modify init behavior.
|
| 39 |
-
super().__init__(pretrained, **kwargs)
|
| 40 |
-
self.max_audios = max_audios
|
| 41 |
-
self.chat_applied: bool = False
|
| 42 |
-
|
| 43 |
-
def _create_tokenizer(
|
| 44 |
-
self,
|
| 45 |
-
pretrained: Union[str, transformers.PreTrainedModel],
|
| 46 |
-
tokenizer: Optional[
|
| 47 |
-
Union[
|
| 48 |
-
str,
|
| 49 |
-
transformers.ProcessorMixin,
|
| 50 |
-
]
|
| 51 |
-
],
|
| 52 |
-
revision: Optional[str] = "main",
|
| 53 |
-
trust_remote_code: Optional[bool] = False,
|
| 54 |
-
**kwargs,
|
| 55 |
-
) -> None:
|
| 56 |
-
"""
|
| 57 |
-
Helper method during initialization.
|
| 58 |
-
For the multimodal variant, we initialize not just
|
| 59 |
-
`self.tokenizer` but also `self.processor`.
|
| 60 |
-
"""
|
| 61 |
-
|
| 62 |
-
if tokenizer:
|
| 63 |
-
if isinstance(tokenizer, str):
|
| 64 |
-
return transformers.AutoTokenizer.from_pretrained(
|
| 65 |
-
tokenizer,
|
| 66 |
-
revision=revision,
|
| 67 |
-
trust_remote_code=trust_remote_code,
|
| 68 |
-
# use_fast=use_fast_tokenizer,
|
| 69 |
-
)
|
| 70 |
-
else:
|
| 71 |
-
assert isinstance(
|
| 72 |
-
tokenizer, transformers.ProcessorMixin
|
| 73 |
-
) # TODO: check this condition
|
| 74 |
-
return tokenizer
|
| 75 |
-
|
| 76 |
-
# Get tokenizer based on 'pretrained'
|
| 77 |
-
if isinstance(pretrained, str):
|
| 78 |
-
model_name = pretrained
|
| 79 |
-
else:
|
| 80 |
-
# get the HF hub name via accessor on model
|
| 81 |
-
model_name = self.model.name_or_path
|
| 82 |
-
|
| 83 |
-
self.processor = transformers.AutoProcessor.from_pretrained(
|
| 84 |
-
model_name,
|
| 85 |
-
revision=revision,
|
| 86 |
-
trust_remote_code=trust_remote_code,
|
| 87 |
-
# use_fast=use_fast_tokenizer,
|
| 88 |
-
)
|
| 89 |
-
self.tokenizer = self.processor.tokenizer
|
| 90 |
-
|
| 91 |
-
def apply_chat_template(
|
| 92 |
-
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
|
| 93 |
-
) -> str:
|
| 94 |
-
"""
|
| 95 |
-
Method to apply a chat template to a list of chat history between user and model.
|
| 96 |
-
"""
|
| 97 |
-
|
| 98 |
-
chat_templated = self.processor.apply_chat_template(
|
| 99 |
-
chat_history, tokenize=False, add_generation_prompt=add_generation_prompt
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
return chat_templated
|
| 103 |
-
|
| 104 |
-
def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
|
| 105 |
-
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
|
| 106 |
-
do_sample = generation_kwargs.get("do_sample", None)
|
| 107 |
-
|
| 108 |
-
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
|
| 109 |
-
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
|
| 110 |
-
generation_kwargs["do_sample"] = do_sample = False
|
| 111 |
-
|
| 112 |
-
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
|
| 113 |
-
generation_kwargs.pop("temperature")
|
| 114 |
-
|
| 115 |
-
stopping_criteria = stop_sequences_criteria(
|
| 116 |
-
self.tokenizer,
|
| 117 |
-
stop,
|
| 118 |
-
inputs["input_ids"].shape[1],
|
| 119 |
-
inputs["input_ids"].shape[0],
|
| 120 |
-
)
|
| 121 |
-
return self.model.generate(
|
| 122 |
-
**inputs,
|
| 123 |
-
max_length=max_length,
|
| 124 |
-
stopping_criteria=stopping_criteria,
|
| 125 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
| 126 |
-
use_cache=True,
|
| 127 |
-
**generation_kwargs,
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
def tok_batch_multimodal_encode(
|
| 131 |
-
self,
|
| 132 |
-
strings: List[str], # note that input signature of this fn is different
|
| 133 |
-
audios: List[List],
|
| 134 |
-
padding_side: str = "left",
|
| 135 |
-
left_truncate_len: int = None,
|
| 136 |
-
truncation: bool = False,
|
| 137 |
-
) -> Union[
|
| 138 |
-
BatchEncoding, Dict[str, torch.Tensor]
|
| 139 |
-
]: # note that this return signature differs from HFLM tok_batch_encode.
|
| 140 |
-
# NOTE: here, we replace <audio> tags with our model's corresponding image_token string value.
|
| 141 |
-
def _replace_placeholder(placeholder, strings):
|
| 142 |
-
return [
|
| 143 |
-
replace_placeholders(
|
| 144 |
-
string,
|
| 145 |
-
placeholder,
|
| 146 |
-
"<|audio_bos|><|AUDIO|><|audio_eos|>",
|
| 147 |
-
self.max_audios,
|
| 148 |
-
)
|
| 149 |
-
for string in strings
|
| 150 |
-
]
|
| 151 |
-
|
| 152 |
-
if not self.chat_applied:
|
| 153 |
-
# TODO<baber>: This still keeps the whitespace in the image placeholder, which is not ideal.
|
| 154 |
-
for placeholder in DEFAULT_AUDIO_PLACEHOLDERS:
|
| 155 |
-
strings = _replace_placeholder(placeholder, strings)
|
| 156 |
-
|
| 157 |
-
encoding = self.processor(
|
| 158 |
-
audios=audios,
|
| 159 |
-
text=strings,
|
| 160 |
-
padding=True,
|
| 161 |
-
return_tensors="pt",
|
| 162 |
-
# **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added?
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
encoding.to( # TODO: our other tokenization methods in HFLM don't typically move to device. this breaks convention
|
| 166 |
-
self.device, self.model.dtype
|
| 167 |
-
) # TODO: This only casts the pixel values. Should they always be float16?
|
| 168 |
-
|
| 169 |
-
return encoding
|
| 170 |
-
|
| 171 |
-
def generate_until(
|
| 172 |
-
self, requests: List[Instance], disable_tqdm: bool = False
|
| 173 |
-
) -> List[str]:
|
| 174 |
-
res = []
|
| 175 |
-
|
| 176 |
-
def _collate(x):
|
| 177 |
-
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 178 |
-
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 179 |
-
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 180 |
-
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 181 |
-
# automatic adaptive batches much much easier to implement
|
| 182 |
-
# - any OOMs will happen right away rather than near the end
|
| 183 |
-
toks = self.tok_encode(x[0])
|
| 184 |
-
return -len(toks), x[0]
|
| 185 |
-
|
| 186 |
-
pbar = tqdm(
|
| 187 |
-
total=len(requests),
|
| 188 |
-
disable=(disable_tqdm or (self.rank != 0)),
|
| 189 |
-
desc="Running generate_until requests with text+audio input",
|
| 190 |
-
)
|
| 191 |
-
# TODO: port auto-batch sizing into this.
|
| 192 |
-
|
| 193 |
-
# we group requests by their generation_kwargs,
|
| 194 |
-
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
|
| 195 |
-
# in the same batch.
|
| 196 |
-
re_ords = Collator(
|
| 197 |
-
[reg.args for reg in requests],
|
| 198 |
-
_collate,
|
| 199 |
-
group_by="gen_kwargs",
|
| 200 |
-
group_fn=lambda x: x[1],
|
| 201 |
-
)
|
| 202 |
-
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
|
| 203 |
-
|
| 204 |
-
### Up to here: was identical to non-multimodal HFLM generate_until ###
|
| 205 |
-
|
| 206 |
-
for chunk in chunks:
|
| 207 |
-
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
|
| 208 |
-
|
| 209 |
-
audios = []
|
| 210 |
-
for audio_lst_dict in aux_arguments:
|
| 211 |
-
for audio in audio_lst_dict["audio"]:
|
| 212 |
-
audios.append(audio["array"])
|
| 213 |
-
|
| 214 |
-
if not isinstance(contexts, list):
|
| 215 |
-
contexts = list(
|
| 216 |
-
contexts
|
| 217 |
-
) # for Qwen2-VL, processor is unhappy accepting a tuple of strings instead of a list.
|
| 218 |
-
# TODO: could we upstream this workaround to HF?
|
| 219 |
-
### this part onward: same as HFLM ###
|
| 220 |
-
|
| 221 |
-
# we assume all gen kwargs in the batch are the same
|
| 222 |
-
# this is safe to assume because the `grouper` object ensures it.
|
| 223 |
-
gen_kwargs = all_gen_kwargs[0]
|
| 224 |
-
# unpack our keyword arguments.
|
| 225 |
-
until = None
|
| 226 |
-
if isinstance(gen_kwargs, dict):
|
| 227 |
-
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
|
| 228 |
-
if "until" in kwargs.keys():
|
| 229 |
-
until = kwargs.pop("until")
|
| 230 |
-
if isinstance(until, str):
|
| 231 |
-
until = [until]
|
| 232 |
-
elif not isinstance(until, list):
|
| 233 |
-
raise ValueError(
|
| 234 |
-
f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
|
| 235 |
-
)
|
| 236 |
-
else:
|
| 237 |
-
raise ValueError(
|
| 238 |
-
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
|
| 239 |
-
)
|
| 240 |
-
# add EOS token to stop sequences
|
| 241 |
-
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
|
| 242 |
-
if not until:
|
| 243 |
-
until = [eos]
|
| 244 |
-
else:
|
| 245 |
-
until.append(eos)
|
| 246 |
-
if "max_gen_toks" in kwargs.keys():
|
| 247 |
-
max_gen_toks = kwargs.pop("max_gen_toks")
|
| 248 |
-
else:
|
| 249 |
-
max_gen_toks = self.max_gen_toks
|
| 250 |
-
|
| 251 |
-
## end stuff that's entirely copied verbatim from HFLM ###
|
| 252 |
-
|
| 253 |
-
max_ctx_len = self.max_length - max_gen_toks
|
| 254 |
-
|
| 255 |
-
inputs = self.tok_batch_multimodal_encode(
|
| 256 |
-
contexts,
|
| 257 |
-
audios,
|
| 258 |
-
left_truncate_len=max_ctx_len,
|
| 259 |
-
truncation=self.truncation,
|
| 260 |
-
)
|
| 261 |
-
|
| 262 |
-
context_enc = inputs["input_ids"]
|
| 263 |
-
|
| 264 |
-
if "max_length" not in kwargs:
|
| 265 |
-
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
|
| 266 |
-
inputs["input_ids"] = inputs["input_ids"].to("cuda")
|
| 267 |
-
inputs.input_ids = inputs.input_ids.to("cuda")
|
| 268 |
-
cont = self._model_multimodal_generate(inputs, stop=until, **kwargs)
|
| 269 |
-
|
| 270 |
-
del inputs
|
| 271 |
-
torch.cuda.empty_cache()
|
| 272 |
-
import gc
|
| 273 |
-
|
| 274 |
-
gc.collect()
|
| 275 |
-
|
| 276 |
-
### essentially same as HFLM beyond this line!
|
| 277 |
-
|
| 278 |
-
cont_toks_list = cont.tolist()
|
| 279 |
-
for cont_toks, context in zip(cont_toks_list, contexts):
|
| 280 |
-
# discard context + left-padding toks if using causal decoder-only VLM
|
| 281 |
-
cont_toks = cont_toks[context_enc.shape[1] :]
|
| 282 |
-
|
| 283 |
-
s = self.tok_decode(cont_toks)
|
| 284 |
-
|
| 285 |
-
res.append(s)
|
| 286 |
-
self.cache_hook.add_partial(
|
| 287 |
-
"generate_until", (context, gen_kwargs), s
|
| 288 |
-
) # TODO: cache key for multimodal input should be what?
|
| 289 |
-
pbar.update(1)
|
| 290 |
-
# reorder this group of results back to original unsorted form
|
| 291 |
-
res = re_ords.get_original(res)
|
| 292 |
-
|
| 293 |
-
pbar.close()
|
| 294 |
-
return res
|
| 295 |
-
|
| 296 |
-
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
|
| 297 |
-
raise NotImplementedError(
|
| 298 |
-
"model type `hf-audiolm` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ",
|
| 299 |
-
"this is because we do not support measuring the loglikelihood a model assigns to an image.",
|
| 300 |
-
)
|
| 301 |
-
|
| 302 |
-
def loglikelihood(
|
| 303 |
-
self, requests: List[Instance], disable_tqdm: bool = False
|
| 304 |
-
) -> List[Tuple[float, bool]]:
|
| 305 |
-
raise NotImplementedError(
|
| 306 |
-
"'loglikelihood' requests for model type `hf-audiolm` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
|
| 307 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/hf_steered.py
DELETED
|
@@ -1,243 +0,0 @@
|
|
| 1 |
-
from contextlib import contextmanager
|
| 2 |
-
from functools import partial
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
from typing import Any, Callable, Generator, Optional, Union
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
from peft.peft_model import PeftModel
|
| 8 |
-
from torch import Tensor, nn
|
| 9 |
-
from transformers import PreTrainedModel
|
| 10 |
-
|
| 11 |
-
from lm_eval.api.registry import register_model
|
| 12 |
-
from lm_eval.models.huggingface import HFLM
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
@contextmanager
|
| 16 |
-
def steer(
|
| 17 |
-
model: Union[PreTrainedModel, PeftModel], hook_to_steer: dict[str, Callable]
|
| 18 |
-
) -> Generator[None, Any, None]:
|
| 19 |
-
"""
|
| 20 |
-
Context manager that temporarily hooks models and steers them.
|
| 21 |
-
|
| 22 |
-
Args:
|
| 23 |
-
model: The transformer model to hook
|
| 24 |
-
hook_to_steer: Dictionary mapping hookpoints to steering functions
|
| 25 |
-
|
| 26 |
-
Yields:
|
| 27 |
-
None
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
def create_hook(hookpoint: str):
|
| 31 |
-
def hook_fn(module: nn.Module, input: Any, output: Tensor):
|
| 32 |
-
# If output is a tuple (like in some transformer layers), take first element
|
| 33 |
-
if isinstance(output, tuple):
|
| 34 |
-
output = (hook_to_steer[hookpoint](output[0]), *output[1:]) # type: ignore
|
| 35 |
-
else:
|
| 36 |
-
output = hook_to_steer[hookpoint](output)
|
| 37 |
-
|
| 38 |
-
return output
|
| 39 |
-
|
| 40 |
-
return hook_fn
|
| 41 |
-
|
| 42 |
-
handles = []
|
| 43 |
-
hookpoints = list(hook_to_steer.keys())
|
| 44 |
-
|
| 45 |
-
for name, module in model.base_model.named_modules():
|
| 46 |
-
if name in hookpoints:
|
| 47 |
-
handle = module.register_forward_hook(create_hook(name))
|
| 48 |
-
handles.append(handle)
|
| 49 |
-
|
| 50 |
-
if len(handles) != len(hookpoints):
|
| 51 |
-
raise ValueError(f"Not all hookpoints could be resolved: {hookpoints}")
|
| 52 |
-
|
| 53 |
-
try:
|
| 54 |
-
yield None
|
| 55 |
-
finally:
|
| 56 |
-
for handle in handles:
|
| 57 |
-
handle.remove()
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
@register_model("steered")
|
| 61 |
-
class SteeredModel(HFLM):
|
| 62 |
-
hook_to_steer: dict[str, Callable]
|
| 63 |
-
|
| 64 |
-
def __init__(
|
| 65 |
-
self,
|
| 66 |
-
pretrained: str,
|
| 67 |
-
steer_path: str,
|
| 68 |
-
device: Optional[str] = None,
|
| 69 |
-
**kwargs,
|
| 70 |
-
):
|
| 71 |
-
"""
|
| 72 |
-
HFLM with a steered forward pass.
|
| 73 |
-
|
| 74 |
-
To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
|
| 75 |
-
provide the path to a CSV file with the following columns (example rows are provided below):
|
| 76 |
-
|
| 77 |
-
loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,sae_id,description,
|
| 78 |
-
sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,
|
| 79 |
-
sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,layer_20/width_16k/canonical,increase dogs,
|
| 80 |
-
|
| 81 |
-
To load steering vectors directly, provide the path to a pytorch (.pt) file with content in the following format:
|
| 82 |
-
|
| 83 |
-
{
|
| 84 |
-
hookpoint: {
|
| 85 |
-
"steering_vector": <torch.Tensor>,
|
| 86 |
-
"steering_coefficient": <float>,
|
| 87 |
-
"action": <Literal["add", "clamp"]>,
|
| 88 |
-
"bias": <torch.Tensor | None>,
|
| 89 |
-
},
|
| 90 |
-
...
|
| 91 |
-
}
|
| 92 |
-
"""
|
| 93 |
-
super().__init__(pretrained=pretrained, device=device, **kwargs)
|
| 94 |
-
|
| 95 |
-
if steer_path.endswith(".pt") or steer_path.endswith(".pth"):
|
| 96 |
-
with open(steer_path, "rb") as f:
|
| 97 |
-
steer_config: dict[str, dict[str, Any]] = torch.load(
|
| 98 |
-
f, weights_only=True
|
| 99 |
-
)
|
| 100 |
-
elif steer_path.endswith(".csv"):
|
| 101 |
-
steer_config = self.derive_steer_config(steer_path)
|
| 102 |
-
else:
|
| 103 |
-
raise ValueError(f"Unknown steer file type: {steer_path}")
|
| 104 |
-
|
| 105 |
-
hook_to_steer = {}
|
| 106 |
-
for hookpoint, steer_info in steer_config.items():
|
| 107 |
-
action = steer_info["action"]
|
| 108 |
-
steering_coefficient = steer_info["steering_coefficient"]
|
| 109 |
-
steering_vector = (
|
| 110 |
-
steer_info["steering_vector"].to(self.device).to(self.model.dtype)
|
| 111 |
-
)
|
| 112 |
-
bias = (
|
| 113 |
-
steer_info["bias"].to(self.device).to(self.model.dtype)
|
| 114 |
-
if steer_info["bias"] is not None
|
| 115 |
-
else None
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
if action == "add":
|
| 119 |
-
# Steers the model by adding some multiple of a steering vector to all sequence positions.
|
| 120 |
-
hook_to_steer[hookpoint] = (
|
| 121 |
-
lambda acts: acts + steering_coefficient * steering_vector
|
| 122 |
-
)
|
| 123 |
-
elif action == "clamp":
|
| 124 |
-
hook_to_steer[hookpoint] = partial(
|
| 125 |
-
self.clamp,
|
| 126 |
-
steering_vector=steering_vector,
|
| 127 |
-
value=steering_coefficient,
|
| 128 |
-
bias=bias,
|
| 129 |
-
)
|
| 130 |
-
else:
|
| 131 |
-
raise ValueError(f"Unknown hook type: {action}")
|
| 132 |
-
|
| 133 |
-
self.hook_to_steer = hook_to_steer
|
| 134 |
-
|
| 135 |
-
@classmethod
|
| 136 |
-
def derive_steer_config(cls, steer_path: str):
|
| 137 |
-
"""Derive a dictionary of steering vectors from sparse model(/s) specified in a CSV file."""
|
| 138 |
-
import pandas as pd
|
| 139 |
-
|
| 140 |
-
df = pd.read_csv(steer_path)
|
| 141 |
-
steer_data: dict[str, dict[str, Any]] = {}
|
| 142 |
-
|
| 143 |
-
if any(df["loader"] == "sparsify"):
|
| 144 |
-
from sparsify import SparseCoder
|
| 145 |
-
if any(df["loader"] == "sae_lens"):
|
| 146 |
-
from sae_lens import SAE
|
| 147 |
-
|
| 148 |
-
sae_cache = {}
|
| 149 |
-
|
| 150 |
-
def load_from_sae_lens(sae_release: str, sae_id: str):
|
| 151 |
-
cache_key = (sae_release, sae_id)
|
| 152 |
-
if cache_key not in sae_cache:
|
| 153 |
-
sae_cache[cache_key] = SAE.from_pretrained(sae_release, sae_id)[0]
|
| 154 |
-
|
| 155 |
-
return sae_cache[cache_key]
|
| 156 |
-
|
| 157 |
-
for _, row in df.iterrows():
|
| 158 |
-
action = row.get("action", "add")
|
| 159 |
-
sparse_name = row["sparse_model"]
|
| 160 |
-
hookpoint = row["hookpoint"]
|
| 161 |
-
feature_index = int(row["feature_index"])
|
| 162 |
-
steering_coefficient = float(row["steering_coefficient"])
|
| 163 |
-
loader = row.get("loader", "sparsify")
|
| 164 |
-
|
| 165 |
-
if loader == "sparsify":
|
| 166 |
-
name_path = Path(sparse_name)
|
| 167 |
-
|
| 168 |
-
sparse_coder = (
|
| 169 |
-
SparseCoder.load_from_disk(name_path / hookpoint)
|
| 170 |
-
if name_path.exists()
|
| 171 |
-
else SparseCoder.load_from_hub(sparse_name, hookpoint)
|
| 172 |
-
)
|
| 173 |
-
assert sparse_coder.W_dec is not None
|
| 174 |
-
|
| 175 |
-
steering_vector = sparse_coder.W_dec[feature_index]
|
| 176 |
-
bias = sparse_coder.b_dec
|
| 177 |
-
|
| 178 |
-
elif loader == "sae_lens":
|
| 179 |
-
sparse_coder = load_from_sae_lens(
|
| 180 |
-
sae_release=sparse_name, sae_id=row["sae_id"]
|
| 181 |
-
)
|
| 182 |
-
steering_vector = sparse_coder.W_dec[feature_index]
|
| 183 |
-
bias = sparse_coder.b_dec
|
| 184 |
-
if hookpoint == "" or pd.isna(hookpoint):
|
| 185 |
-
hookpoint = sparse_coder.cfg.hook_name
|
| 186 |
-
else:
|
| 187 |
-
raise ValueError(f"Unknown loader: {loader}")
|
| 188 |
-
|
| 189 |
-
steer_data[hookpoint] = {
|
| 190 |
-
"action": action,
|
| 191 |
-
"steering_coefficient": steering_coefficient,
|
| 192 |
-
"steering_vector": steering_vector,
|
| 193 |
-
"bias": bias,
|
| 194 |
-
}
|
| 195 |
-
|
| 196 |
-
return steer_data
|
| 197 |
-
|
| 198 |
-
@classmethod
|
| 199 |
-
def clamp(
|
| 200 |
-
cls,
|
| 201 |
-
acts: Tensor,
|
| 202 |
-
steering_vector: Tensor,
|
| 203 |
-
value: float,
|
| 204 |
-
bias: Optional[Tensor] = None,
|
| 205 |
-
):
|
| 206 |
-
"""Clamps a direction of the activations to be the steering vector * the value.
|
| 207 |
-
|
| 208 |
-
Args:
|
| 209 |
-
acts (Tensor): The activations tensor to edit of shape [batch, pos, features]
|
| 210 |
-
steering_vector (Tensor): A direction to clamp of shape [features]
|
| 211 |
-
value (float): Value to clamp the direction to
|
| 212 |
-
bias (Tensor | None): Optional bias to add to the activations
|
| 213 |
-
|
| 214 |
-
Returns:
|
| 215 |
-
Tensor: The modified activations with the specified direction clamped
|
| 216 |
-
"""
|
| 217 |
-
|
| 218 |
-
if bias is not None:
|
| 219 |
-
acts = acts - bias
|
| 220 |
-
|
| 221 |
-
direction = steering_vector / torch.norm(steering_vector)
|
| 222 |
-
proj_magnitude = torch.sum(acts * direction, dim=-1, keepdim=True)
|
| 223 |
-
orthogonal_component = acts - proj_magnitude * direction
|
| 224 |
-
|
| 225 |
-
clamped = orthogonal_component + direction * value
|
| 226 |
-
|
| 227 |
-
if bias is not None:
|
| 228 |
-
return clamped + bias
|
| 229 |
-
|
| 230 |
-
return clamped
|
| 231 |
-
|
| 232 |
-
def forward(self, *args, **kwargs):
|
| 233 |
-
with torch.no_grad():
|
| 234 |
-
with steer(self.model, self.hook_to_steer):
|
| 235 |
-
return self.model.forward(*args, **kwargs)
|
| 236 |
-
|
| 237 |
-
def _model_call(self, *args, **kwargs):
|
| 238 |
-
with steer(self.model, self.hook_to_steer):
|
| 239 |
-
return super()._model_call(*args, **kwargs)
|
| 240 |
-
|
| 241 |
-
def _model_generate(self, *args, **kwargs):
|
| 242 |
-
with steer(self.model, self.hook_to_steer):
|
| 243 |
-
return super()._model_generate(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/hf_vlms.py
DELETED
|
@@ -1,757 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import logging
|
| 3 |
-
from typing import Dict, List, Optional, Tuple, Union
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
import transformers
|
| 8 |
-
from tqdm import tqdm
|
| 9 |
-
from transformers import BatchEncoding
|
| 10 |
-
|
| 11 |
-
from lm_eval.api.instance import Instance
|
| 12 |
-
from lm_eval.api.registry import register_model
|
| 13 |
-
from lm_eval.models.huggingface import HFLM
|
| 14 |
-
from lm_eval.models.utils import (
|
| 15 |
-
Collator,
|
| 16 |
-
flatten_image_list,
|
| 17 |
-
handle_stop_sequences,
|
| 18 |
-
pad_and_concat,
|
| 19 |
-
replace_placeholders,
|
| 20 |
-
resize_image,
|
| 21 |
-
stop_sequences_criteria,
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
DEFAULT_IMAGE_PLACEHOLDER = "<image>"
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
eval_logger = logging.getLogger(__name__)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
@register_model("hf-multimodal")
|
| 32 |
-
class HFMultimodalLM(HFLM):
|
| 33 |
-
"""
|
| 34 |
-
An abstracted Hugging Face model class for multimodal LMs like Llava and Idefics.
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq
|
| 38 |
-
MULTIMODAL = True # flag to indicate, for now, that this model type can run multimodal requests
|
| 39 |
-
|
| 40 |
-
def __init__(
|
| 41 |
-
self,
|
| 42 |
-
pretrained: Union[str, transformers.PreTrainedModel],
|
| 43 |
-
image_token_id: Optional[int] = None,
|
| 44 |
-
image_string: Optional[str] = None,
|
| 45 |
-
interleave: bool = True,
|
| 46 |
-
# TODO: handle whitespace in image placeholder (replacement)
|
| 47 |
-
max_images: Optional[int] = 999,
|
| 48 |
-
convert_img_format=False,
|
| 49 |
-
# For image resizing
|
| 50 |
-
min_pixels: Optional[int] = None,
|
| 51 |
-
max_pixels: Optional[int] = None,
|
| 52 |
-
image_width: Optional[int] = None,
|
| 53 |
-
image_height: Optional[int] = None,
|
| 54 |
-
image_max_side: Optional[int] = None,
|
| 55 |
-
**kwargs,
|
| 56 |
-
):
|
| 57 |
-
self.image_width = image_width
|
| 58 |
-
self.image_height = image_height
|
| 59 |
-
self.image_max_side = image_max_side
|
| 60 |
-
if self.image_max_side and (self.image_width or self.image_height):
|
| 61 |
-
raise ValueError(
|
| 62 |
-
"Ambiguous config for image resize: you can not specify both "
|
| 63 |
-
"image_max_side and (image_width or image_height)"
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
# init pixels before calling tokenizer creation to avoid errors
|
| 67 |
-
self.pixels = ({"min_pixels": min_pixels} if min_pixels else {}) | (
|
| 68 |
-
{"max_pixels": max_pixels} if max_pixels else {}
|
| 69 |
-
)
|
| 70 |
-
|
| 71 |
-
# We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
|
| 72 |
-
# modify init behavior.
|
| 73 |
-
super().__init__(pretrained, **kwargs)
|
| 74 |
-
|
| 75 |
-
assert self.batch_size != "auto", (
|
| 76 |
-
"Batch size 'auto' is not yet supported for hf-multimodal models."
|
| 77 |
-
)
|
| 78 |
-
self.chat_applied: bool = False
|
| 79 |
-
# TODO: phi-3.5 "image placeholders" are <image_1>, <image_2>, ... in order. how to handle this case
|
| 80 |
-
|
| 81 |
-
# HF AutoModelForVision2Seq models have an `image_token_id` value in their configs
|
| 82 |
-
# denoting the token which indicates a location where an image will be substituted in.
|
| 83 |
-
# This can take different string values across models, e.g. <image> for Idefics2 and <|image_pad|> for Qwen2-VL
|
| 84 |
-
self.interleave = interleave
|
| 85 |
-
self.max_images = max_images
|
| 86 |
-
self.rgb = convert_img_format
|
| 87 |
-
# WARNING: improperly set image_token_id can lead to ignored image input or other (potentially silent) errors!
|
| 88 |
-
if not image_string:
|
| 89 |
-
self.image_token_id = (
|
| 90 |
-
int(image_token_id)
|
| 91 |
-
if image_token_id
|
| 92 |
-
else (
|
| 93 |
-
getattr(self.config, "image_token_id", None)
|
| 94 |
-
or getattr(self.config, "image_token_index", None)
|
| 95 |
-
)
|
| 96 |
-
)
|
| 97 |
-
assert self.image_token_id is not None, (
|
| 98 |
-
"Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
|
| 99 |
-
)
|
| 100 |
-
# get the string this token ID corresponds to
|
| 101 |
-
self.image_token = self.tok_decode(
|
| 102 |
-
[self.image_token_id], skip_special_tokens=False
|
| 103 |
-
)
|
| 104 |
-
if image_token_id is not None:
|
| 105 |
-
eval_logger.info(
|
| 106 |
-
f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
|
| 107 |
-
)
|
| 108 |
-
else:
|
| 109 |
-
eval_logger.info(
|
| 110 |
-
f"A non-default image_token string with string value image_string='{image_string}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
|
| 111 |
-
)
|
| 112 |
-
self.image_token = image_string
|
| 113 |
-
|
| 114 |
-
def _create_tokenizer(
|
| 115 |
-
self,
|
| 116 |
-
pretrained: Union[str, transformers.PreTrainedModel],
|
| 117 |
-
tokenizer: Optional[
|
| 118 |
-
Union[
|
| 119 |
-
str,
|
| 120 |
-
transformers.ProcessorMixin,
|
| 121 |
-
]
|
| 122 |
-
],
|
| 123 |
-
revision: Optional[str] = "main",
|
| 124 |
-
trust_remote_code: Optional[bool] = False,
|
| 125 |
-
**kwargs,
|
| 126 |
-
) -> None:
|
| 127 |
-
"""
|
| 128 |
-
Helper method during initialization.
|
| 129 |
-
|
| 130 |
-
For the multimodal variant, we initialize not just
|
| 131 |
-
`self.tokenizer` but also `self.processor`.
|
| 132 |
-
"""
|
| 133 |
-
|
| 134 |
-
if tokenizer:
|
| 135 |
-
if isinstance(tokenizer, str):
|
| 136 |
-
return transformers.AutoProcessor.from_pretrained(
|
| 137 |
-
tokenizer,
|
| 138 |
-
revision=revision,
|
| 139 |
-
trust_remote_code=trust_remote_code,
|
| 140 |
-
# use_fast=use_fast_tokenizer,
|
| 141 |
-
)
|
| 142 |
-
else:
|
| 143 |
-
assert isinstance(
|
| 144 |
-
tokenizer, transformers.ProcessorMixin
|
| 145 |
-
) # TODO: check this condition
|
| 146 |
-
return tokenizer
|
| 147 |
-
|
| 148 |
-
# Get tokenizer based on 'pretrained'
|
| 149 |
-
if isinstance(pretrained, str):
|
| 150 |
-
model_name = pretrained
|
| 151 |
-
else:
|
| 152 |
-
# get the HF hub name via accessor on model
|
| 153 |
-
model_name = self.model.name_or_path
|
| 154 |
-
|
| 155 |
-
self.processor = transformers.AutoProcessor.from_pretrained(
|
| 156 |
-
model_name,
|
| 157 |
-
revision=revision,
|
| 158 |
-
trust_remote_code=trust_remote_code,
|
| 159 |
-
**self.pixels,
|
| 160 |
-
# use_fast=use_fast_tokenizer,
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
self.tokenizer = self.processor.tokenizer
|
| 164 |
-
|
| 165 |
-
def tok_multimodal_encode(
|
| 166 |
-
self, string, images, left_truncate_len=None, add_special_tokens=None
|
| 167 |
-
):
|
| 168 |
-
"""Helper function which encodes an image + string combo using AutoProcessor"""
|
| 169 |
-
# We inherit special token kwarg setup from HFLM.tok_encode
|
| 170 |
-
# special_tokens_kwargs = {}
|
| 171 |
-
|
| 172 |
-
# by default for CausalLM - false or self.add_bos_token is set
|
| 173 |
-
# if add_special_tokens is None:
|
| 174 |
-
# special_tokens_kwargs = {"add_special_tokens": False or self.add_bos_token}
|
| 175 |
-
# otherwise the method explicitly defines the value
|
| 176 |
-
# else:
|
| 177 |
-
# special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
|
| 178 |
-
|
| 179 |
-
# encode text+images
|
| 180 |
-
# TODO: why does (Qwen2-VL) processor error when attempting to add special tokens to text?
|
| 181 |
-
encoding = self.processor(
|
| 182 |
-
text=string, images=images, return_tensors=None
|
| 183 |
-
) # , **special_tokens_kwargs)
|
| 184 |
-
|
| 185 |
-
# remove (and store) our tokenized text
|
| 186 |
-
text_encoding = encoding.pop("input_ids")
|
| 187 |
-
encoding.pop("attention_mask")
|
| 188 |
-
|
| 189 |
-
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
|
| 190 |
-
if left_truncate_len:
|
| 191 |
-
text_encoding = text_encoding[-left_truncate_len:]
|
| 192 |
-
|
| 193 |
-
return text_encoding, encoding # image_encoding is a dict
|
| 194 |
-
|
| 195 |
-
def _encode_multimodal_pair(self, context, continuation, images):
|
| 196 |
-
"""Helper function to perform the role of TemplateLM._encode_pair
|
| 197 |
-
Except allowing for image input to also be processed alongside `context`.
|
| 198 |
-
|
| 199 |
-
This method is a bit messy due to the need to defer conversion of image and text token input
|
| 200 |
-
into PyTorch tensors until the main inference loop.
|
| 201 |
-
"""
|
| 202 |
-
|
| 203 |
-
n_spaces = len(context) - len(context.rstrip())
|
| 204 |
-
if n_spaces > 0:
|
| 205 |
-
continuation = context[-n_spaces:] + continuation
|
| 206 |
-
context = context[:-n_spaces]
|
| 207 |
-
|
| 208 |
-
# TODO: replace default <image> placeholder with self.image_token, for contexts
|
| 209 |
-
|
| 210 |
-
whole_enc, image_enc = self.tok_multimodal_encode(
|
| 211 |
-
context + continuation, images
|
| 212 |
-
)
|
| 213 |
-
context_enc, _ = self.tok_multimodal_encode(context, images)
|
| 214 |
-
|
| 215 |
-
# tok_multimodal_encode returns List[List[int]] for tokenized text. Get rid of the batch dim
|
| 216 |
-
# since we only are encoding a single string.
|
| 217 |
-
# TODO: this is a bit hacky, it'd be nice to make this generally cleaner
|
| 218 |
-
whole_enc, context_enc = whole_enc[0], context_enc[0]
|
| 219 |
-
|
| 220 |
-
context_enc_len = len(context_enc)
|
| 221 |
-
continuation_enc = whole_enc[context_enc_len:]
|
| 222 |
-
|
| 223 |
-
return context_enc, continuation_enc, image_enc
|
| 224 |
-
|
| 225 |
-
def apply_chat_template(
|
| 226 |
-
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
|
| 227 |
-
) -> str:
|
| 228 |
-
self.chat_applied = True
|
| 229 |
-
if not self.interleave:
|
| 230 |
-
for content in chat_history:
|
| 231 |
-
c = []
|
| 232 |
-
text = content["content"]
|
| 233 |
-
|
| 234 |
-
# Count and remove image placeholders
|
| 235 |
-
image_count = min(
|
| 236 |
-
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
|
| 237 |
-
)
|
| 238 |
-
text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")
|
| 239 |
-
|
| 240 |
-
# Add image entries
|
| 241 |
-
for _ in range(image_count):
|
| 242 |
-
c.append({"type": "image", "image": None})
|
| 243 |
-
|
| 244 |
-
# Add single text entry at the end
|
| 245 |
-
c.append({"type": "text", "text": text})
|
| 246 |
-
|
| 247 |
-
content["content"] = c
|
| 248 |
-
else:
|
| 249 |
-
for content in chat_history:
|
| 250 |
-
c = []
|
| 251 |
-
text = content["content"]
|
| 252 |
-
expected_image_count = min(
|
| 253 |
-
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
|
| 254 |
-
)
|
| 255 |
-
actual_image_count = 0
|
| 256 |
-
|
| 257 |
-
text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)
|
| 258 |
-
|
| 259 |
-
for i, part in enumerate(text_parts):
|
| 260 |
-
# TODO: concatenate text parts (esp. if skipping images)?
|
| 261 |
-
if part: # Add non-empty text parts
|
| 262 |
-
c.append({"type": "text", "text": part})
|
| 263 |
-
if (
|
| 264 |
-
(i < len(text_parts) - 1) and i < self.max_images
|
| 265 |
-
): # Add image placeholder after each split except the last
|
| 266 |
-
c.append({"type": "image"})
|
| 267 |
-
actual_image_count += 1
|
| 268 |
-
|
| 269 |
-
content["content"] = c
|
| 270 |
-
|
| 271 |
-
if actual_image_count != expected_image_count:
|
| 272 |
-
raise ValueError(
|
| 273 |
-
f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
|
| 274 |
-
)
|
| 275 |
-
|
| 276 |
-
return self.processor.apply_chat_template(
|
| 277 |
-
chat_history,
|
| 278 |
-
add_generation_prompt=add_generation_prompt,
|
| 279 |
-
continue_final_message=not add_generation_prompt,
|
| 280 |
-
)
|
| 281 |
-
|
| 282 |
-
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
|
| 283 |
-
if hasattr(self.processor, "apply_chat_template"):
|
| 284 |
-
_tokenizer = self.tokenizer
|
| 285 |
-
self.tokenizer = self.processor
|
| 286 |
-
|
| 287 |
-
selected_template = super().chat_template(chat_template)
|
| 288 |
-
|
| 289 |
-
self.tokenizer = _tokenizer
|
| 290 |
-
return selected_template
|
| 291 |
-
else:
|
| 292 |
-
return super().chat_template(chat_template)
|
| 293 |
-
|
| 294 |
-
def tok_batch_multimodal_encode(
|
| 295 |
-
self,
|
| 296 |
-
strings: List[str], # note that input signature of this fn is different
|
| 297 |
-
images: List[List], # TODO: images are pil.Image at the moment, update typehint
|
| 298 |
-
padding_side: str = "left",
|
| 299 |
-
left_truncate_len: int = None,
|
| 300 |
-
truncation: bool = False,
|
| 301 |
-
) -> Union[
|
| 302 |
-
BatchEncoding, Dict[str, torch.Tensor]
|
| 303 |
-
]: # note that this return signature differs from HFLM tok_batch_encode.
|
| 304 |
-
# NOTE: here, we replace <image> tags with our model's corresponding image_token string value.
|
| 305 |
-
if not self.chat_applied:
|
| 306 |
-
# TODO<baber>: This still keeps the whitespace in the image placeholder, which is not ideal.
|
| 307 |
-
strings = [
|
| 308 |
-
replace_placeholders(
|
| 309 |
-
string, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
|
| 310 |
-
)
|
| 311 |
-
for string in strings
|
| 312 |
-
]
|
| 313 |
-
|
| 314 |
-
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
|
| 315 |
-
old_padding_side = self.tokenizer.padding_side
|
| 316 |
-
self.tokenizer.padding_side = padding_side
|
| 317 |
-
|
| 318 |
-
# add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
|
| 319 |
-
|
| 320 |
-
images = [img[: self.max_images] for img in images]
|
| 321 |
-
if self.rgb:
|
| 322 |
-
images = [[img.convert("RGB") for img in sublist] for sublist in images]
|
| 323 |
-
|
| 324 |
-
# certain models like llava expect a single-level image list even for bs>1, multi-image. TODO: port this over to loglikelihoods
|
| 325 |
-
if getattr(self.config, "model_type", "") == "llava":
|
| 326 |
-
images = flatten_image_list(images)
|
| 327 |
-
|
| 328 |
-
encoding = self.processor(
|
| 329 |
-
images=images,
|
| 330 |
-
text=strings,
|
| 331 |
-
truncation=truncation,
|
| 332 |
-
padding="longest",
|
| 333 |
-
return_tensors="pt",
|
| 334 |
-
# **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added?
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
encoding.to( # TODO: our other tokenization methods in HFLM don't typically move to device. this breaks convention
|
| 338 |
-
self.device, self.model.dtype
|
| 339 |
-
) # TODO: This only casts the pixel values. Should they always be float16?
|
| 340 |
-
if left_truncate_len:
|
| 341 |
-
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
|
| 342 |
-
encoding["attention_mask"] = encoding["attention_mask"][
|
| 343 |
-
:, -left_truncate_len:
|
| 344 |
-
]
|
| 345 |
-
self.tokenizer.padding_side = old_padding_side
|
| 346 |
-
|
| 347 |
-
return encoding
|
| 348 |
-
|
| 349 |
-
def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None):
|
| 350 |
-
"""
|
| 351 |
-
TODO: update docstring
|
| 352 |
-
"""
|
| 353 |
-
# note: imgs is a dict.
|
| 354 |
-
with torch.no_grad():
|
| 355 |
-
return self.model(inps, **imgs).logits
|
| 356 |
-
|
| 357 |
-
def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
|
| 358 |
-
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
|
| 359 |
-
do_sample = generation_kwargs.get("do_sample", None)
|
| 360 |
-
|
| 361 |
-
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
|
| 362 |
-
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
|
| 363 |
-
generation_kwargs["do_sample"] = do_sample = False
|
| 364 |
-
|
| 365 |
-
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
|
| 366 |
-
generation_kwargs.pop("temperature")
|
| 367 |
-
|
| 368 |
-
stopping_criteria = stop_sequences_criteria(
|
| 369 |
-
self.tokenizer,
|
| 370 |
-
stop,
|
| 371 |
-
inputs["input_ids"].shape[1],
|
| 372 |
-
inputs["input_ids"].shape[0],
|
| 373 |
-
)
|
| 374 |
-
return self.model.generate(
|
| 375 |
-
**inputs,
|
| 376 |
-
max_length=max_length,
|
| 377 |
-
stopping_criteria=stopping_criteria,
|
| 378 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
| 379 |
-
use_cache=True,
|
| 380 |
-
**generation_kwargs,
|
| 381 |
-
)
|
| 382 |
-
|
| 383 |
-
def _batch_images(self, image_encs):
|
| 384 |
-
"""
|
| 385 |
-
Helper function: batch together image encodings across examples in a batch.
|
| 386 |
-
# TODO: for variable-sized images, this may break down.
|
| 387 |
-
"""
|
| 388 |
-
batched_imgs = {}
|
| 389 |
-
for key in image_encs[0].keys():
|
| 390 |
-
batched_imgs[key] = torch.cat(
|
| 391 |
-
[
|
| 392 |
-
torch.tensor(
|
| 393 |
-
image_enc[key], device=self.device, dtype=self.model.dtype
|
| 394 |
-
)
|
| 395 |
-
for image_enc in image_encs
|
| 396 |
-
],
|
| 397 |
-
dim=0,
|
| 398 |
-
)
|
| 399 |
-
return batched_imgs
|
| 400 |
-
|
| 401 |
-
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
|
| 402 |
-
if requests and len(requests[0].args) < 3:
|
| 403 |
-
# Fall back to non-multimodal generation.
|
| 404 |
-
return super().loglikelihood_rolling(requests=requests)
|
| 405 |
-
raise NotImplementedError(
|
| 406 |
-
"model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks ",
|
| 407 |
-
"this is because we do not support measuring the loglikelihood a model assigns to an image.",
|
| 408 |
-
)
|
| 409 |
-
|
| 410 |
-
def loglikelihood(
|
| 411 |
-
self, requests: List[Instance], disable_tqdm: bool = False
|
| 412 |
-
) -> List[Tuple[float, bool]]:
|
| 413 |
-
if requests and len(requests[0].args) < 3:
|
| 414 |
-
# Fall back to non-multimodal generation.
|
| 415 |
-
return super().loglikelihood(requests=requests, disable_tqdm=disable_tqdm)
|
| 416 |
-
raise NotImplementedError(
|
| 417 |
-
"'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
|
| 418 |
-
)
|
| 419 |
-
|
| 420 |
-
new_reqs = []
|
| 421 |
-
for context, continuation, aux_arguments in [req.args for req in requests]:
|
| 422 |
-
if context == "":
|
| 423 |
-
raise ValueError(
|
| 424 |
-
"Must get non-empty context for multimodal requests! You might be trying to run 'loglikelihood_rolling', which is not supported in the multimodal case."
|
| 425 |
-
)
|
| 426 |
-
else:
|
| 427 |
-
visuals = aux_arguments["visual"]
|
| 428 |
-
|
| 429 |
-
context_enc, continuation_enc, image_enc = self._encode_multimodal_pair(
|
| 430 |
-
context, continuation, visuals
|
| 431 |
-
)
|
| 432 |
-
# TODO: key to pick for caching images
|
| 433 |
-
new_reqs.append(
|
| 434 |
-
(
|
| 435 |
-
(context, continuation, visuals),
|
| 436 |
-
context_enc,
|
| 437 |
-
continuation_enc,
|
| 438 |
-
image_enc,
|
| 439 |
-
)
|
| 440 |
-
)
|
| 441 |
-
|
| 442 |
-
return self._multimodal_loglikelihood_tokens(
|
| 443 |
-
new_reqs, disable_tqdm=disable_tqdm
|
| 444 |
-
)
|
| 445 |
-
|
| 446 |
-
def _multimodal_loglikelihood_tokens(
|
| 447 |
-
self,
|
| 448 |
-
requests: List[
|
| 449 |
-
Tuple[Tuple[None, str, str], List[int], List[int], List[int]]
|
| 450 |
-
], # TODO: update typehint to be correct
|
| 451 |
-
disable_tqdm: bool = False,
|
| 452 |
-
override_bs: int = None,
|
| 453 |
-
) -> List[Tuple[float, bool]]:
|
| 454 |
-
res = []
|
| 455 |
-
|
| 456 |
-
# TODO: **improve multimodal collation.** We currently ignore image size when ordering docs. ideally we'd take them into account
|
| 457 |
-
def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
|
| 458 |
-
"""Defines the key for the sorted method"""
|
| 459 |
-
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 460 |
-
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 461 |
-
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 462 |
-
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 463 |
-
# automatic adaptive batches much much easier to implement
|
| 464 |
-
# - any OOMs will happen right away rather than near the end
|
| 465 |
-
toks = req[1] + req[2]
|
| 466 |
-
return -len(toks), tuple(toks)
|
| 467 |
-
|
| 468 |
-
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
|
| 469 |
-
"""Defines the key to group and lookup one-token continuations"""
|
| 470 |
-
# Use with group_by="contexts" (optional)"
|
| 471 |
-
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
|
| 472 |
-
# speeds up some multiple-choice tasks proportionally to the number of choices.
|
| 473 |
-
# groups requests by context+continuation[:-1] and infer on one request/group.
|
| 474 |
-
return req[-1] + req[-3] + req[-2][:-1]
|
| 475 |
-
|
| 476 |
-
re_ord = Collator(
|
| 477 |
-
requests,
|
| 478 |
-
sort_fn=_collate,
|
| 479 |
-
group_by="contexts" # TODO: can't group-by just "contexts" any more, need to incorporate imgs
|
| 480 |
-
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
|
| 481 |
-
and self.logits_cache
|
| 482 |
-
else None,
|
| 483 |
-
group_fn=_lookup_one_token_cont,
|
| 484 |
-
)
|
| 485 |
-
|
| 486 |
-
# automatic (variable) batch size detection for vectorization
|
| 487 |
-
# pull longest context sample from request
|
| 488 |
-
n_reordered_requests = len(re_ord)
|
| 489 |
-
batch_size = (
|
| 490 |
-
self.batch_size
|
| 491 |
-
if self.batch_size != "auto"
|
| 492 |
-
else override_bs
|
| 493 |
-
if override_bs is not None
|
| 494 |
-
else 0
|
| 495 |
-
)
|
| 496 |
-
batch_fn = (
|
| 497 |
-
self._batch_scheduler
|
| 498 |
-
if self.batch_size == "auto"
|
| 499 |
-
and n_reordered_requests > 0
|
| 500 |
-
and not override_bs
|
| 501 |
-
else None
|
| 502 |
-
)
|
| 503 |
-
|
| 504 |
-
chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
|
| 505 |
-
pbar = tqdm(
|
| 506 |
-
total=len(requests),
|
| 507 |
-
disable=(disable_tqdm or (self.rank != 0)),
|
| 508 |
-
desc="Running loglikelihood requests with text+image input",
|
| 509 |
-
)
|
| 510 |
-
for chunk in chunks:
|
| 511 |
-
imgs = []
|
| 512 |
-
inps = []
|
| 513 |
-
cont_toks_list = []
|
| 514 |
-
inplens = []
|
| 515 |
-
|
| 516 |
-
padding_len_inp = None
|
| 517 |
-
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
|
| 518 |
-
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
|
| 519 |
-
# again because vectorizing is annoying
|
| 520 |
-
|
| 521 |
-
for _, context_enc, continuation_enc, image_enc in chunk:
|
| 522 |
-
# sanity check
|
| 523 |
-
assert len(image_enc) > 0
|
| 524 |
-
assert len(context_enc) > 0
|
| 525 |
-
assert len(continuation_enc) > 0
|
| 526 |
-
assert len(continuation_enc) <= self.max_length
|
| 527 |
-
|
| 528 |
-
# how this all works (illustrated on a causal decoder-only setup):
|
| 529 |
-
# CTX CONT
|
| 530 |
-
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
|
| 531 |
-
# model \ \
|
| 532 |
-
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
|
| 533 |
-
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
|
| 534 |
-
|
| 535 |
-
# when too long to fit in context, truncate from the left
|
| 536 |
-
# TODO: assuming that we won't handle enc-dec Vision2Seq models. Is that a safe assumption?
|
| 537 |
-
inp = torch.tensor(
|
| 538 |
-
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
|
| 539 |
-
dtype=torch.long,
|
| 540 |
-
device=self.device,
|
| 541 |
-
)
|
| 542 |
-
(inplen,) = inp.shape
|
| 543 |
-
|
| 544 |
-
padding_len_inp = (
|
| 545 |
-
max(padding_len_inp, inplen)
|
| 546 |
-
if padding_len_inp is not None
|
| 547 |
-
else inplen
|
| 548 |
-
)
|
| 549 |
-
|
| 550 |
-
inps.append(inp) # [1, inp_length]
|
| 551 |
-
cont_toks_list.append(continuation_enc)
|
| 552 |
-
inplens.append(inplen)
|
| 553 |
-
|
| 554 |
-
imgs.append(image_enc)
|
| 555 |
-
|
| 556 |
-
# create encoder attn mask and batched conts, if seq2seq
|
| 557 |
-
call_kwargs = {}
|
| 558 |
-
batched_inps = pad_and_concat(
|
| 559 |
-
padding_len_inp, inps, padding_side="right"
|
| 560 |
-
) # [batch, padding_len_inp]
|
| 561 |
-
# batch our examples' image inputs together
|
| 562 |
-
batched_imgs = self._batch_images(
|
| 563 |
-
imgs
|
| 564 |
-
) # TODO: fix/test for bs>1 case with differently-sized imgs!
|
| 565 |
-
|
| 566 |
-
multi_logits = F.log_softmax(
|
| 567 |
-
self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs),
|
| 568 |
-
dim=-1,
|
| 569 |
-
) # [batch, padding_length (inp or cont), vocab]
|
| 570 |
-
|
| 571 |
-
for (
|
| 572 |
-
request_str,
|
| 573 |
-
ctx_tokens,
|
| 574 |
-
_,
|
| 575 |
-
image_encs,
|
| 576 |
-
), logits, inplen, cont_toks in zip(
|
| 577 |
-
chunk, multi_logits, inplens, cont_toks_list
|
| 578 |
-
):
|
| 579 |
-
# Slice to original seq length
|
| 580 |
-
contlen = len(cont_toks)
|
| 581 |
-
# take only logits in the continuation
|
| 582 |
-
# (discard context toks if decoder-only ; discard right-padding)
|
| 583 |
-
# also discards + checks for "virtual tokens" in the causal LM's input window
|
| 584 |
-
# from prompt/prefix tuning tokens, if applicable
|
| 585 |
-
ctx_len = (
|
| 586 |
-
inplen + (logits.shape[0] - padding_len_inp)
|
| 587 |
-
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
|
| 588 |
-
else None
|
| 589 |
-
)
|
| 590 |
-
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
|
| 591 |
-
logits = logits.unsqueeze(0) # [1, seq, vocab]
|
| 592 |
-
|
| 593 |
-
# Check if per-token argmax is exactly equal to continuation
|
| 594 |
-
greedy_tokens = logits.argmax(dim=-1)
|
| 595 |
-
|
| 596 |
-
# check for one-token continuation cache hits.
|
| 597 |
-
# noop in case group_by != "contexts" or no cache hit and returns the
|
| 598 |
-
# original args. Otherwise, expands the logits batch dimension and yields each
|
| 599 |
-
# batch along with matching continuation tokens and prompt strings.
|
| 600 |
-
# logits -> [1, seq, vocab]
|
| 601 |
-
for request_str, cont_toks, logits in re_ord.get_cache(
|
| 602 |
-
req_str=request_str,
|
| 603 |
-
cxt_toks=ctx_tokens,
|
| 604 |
-
cont_toks=cont_toks,
|
| 605 |
-
logits=logits,
|
| 606 |
-
):
|
| 607 |
-
cont_toks = torch.tensor(
|
| 608 |
-
cont_toks, dtype=torch.long, device=self.device
|
| 609 |
-
).unsqueeze(0) # [1, seq]
|
| 610 |
-
max_equal = (greedy_tokens == cont_toks).all()
|
| 611 |
-
|
| 612 |
-
# Obtain log-probs at the corresponding continuation token indices
|
| 613 |
-
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
|
| 614 |
-
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
|
| 615 |
-
-1
|
| 616 |
-
) # [1, seq]
|
| 617 |
-
|
| 618 |
-
# Answer: (log prob, is-exact-match)
|
| 619 |
-
answer = (float(logits.sum()), bool(max_equal))
|
| 620 |
-
|
| 621 |
-
res.append(answer)
|
| 622 |
-
|
| 623 |
-
self.cache_hook.add_partial(
|
| 624 |
-
"loglikelihood", request_str, answer
|
| 625 |
-
) # TODO: choose convention for adding images into the cache key
|
| 626 |
-
pbar.update(1)
|
| 627 |
-
|
| 628 |
-
pbar.close()
|
| 629 |
-
|
| 630 |
-
return re_ord.get_original(res)
|
| 631 |
-
|
| 632 |
-
def generate_until(
|
| 633 |
-
self, requests: List[Instance], disable_tqdm: bool = False
|
| 634 |
-
) -> List[str]:
|
| 635 |
-
if requests and len(requests[0].args) < 3:
|
| 636 |
-
# Fall back to non-multimodal generation.
|
| 637 |
-
return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
|
| 638 |
-
|
| 639 |
-
res = []
|
| 640 |
-
|
| 641 |
-
def _collate(x):
|
| 642 |
-
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 643 |
-
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 644 |
-
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 645 |
-
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 646 |
-
# automatic adaptive batches much much easier to implement
|
| 647 |
-
# - any OOMs will happen right away rather than near the end
|
| 648 |
-
toks = self.tok_encode(x[0])
|
| 649 |
-
return -len(toks), x[0]
|
| 650 |
-
|
| 651 |
-
pbar = tqdm(
|
| 652 |
-
total=len(requests),
|
| 653 |
-
disable=(disable_tqdm or (self.rank != 0)),
|
| 654 |
-
desc="Running generate_until requests with text+image input",
|
| 655 |
-
)
|
| 656 |
-
# TODO: port auto-batch sizing into this.
|
| 657 |
-
|
| 658 |
-
# we group requests by their generation_kwargs,
|
| 659 |
-
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
|
| 660 |
-
# in the same batch.
|
| 661 |
-
re_ords = Collator(
|
| 662 |
-
[reg.args for reg in requests],
|
| 663 |
-
_collate,
|
| 664 |
-
group_by="gen_kwargs",
|
| 665 |
-
group_fn=lambda x: x[1],
|
| 666 |
-
)
|
| 667 |
-
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
|
| 668 |
-
|
| 669 |
-
### Up to here: was identical to non-multimodal HFLM generate_until ###
|
| 670 |
-
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
|
| 671 |
-
for chunk in chunks:
|
| 672 |
-
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
|
| 673 |
-
|
| 674 |
-
visuals = [
|
| 675 |
-
[
|
| 676 |
-
resize_image(
|
| 677 |
-
img, self.image_width, self.image_height, self.image_max_side
|
| 678 |
-
)
|
| 679 |
-
for img in arg["visual"]
|
| 680 |
-
]
|
| 681 |
-
for arg in aux_arguments
|
| 682 |
-
]
|
| 683 |
-
|
| 684 |
-
if not isinstance(contexts, list):
|
| 685 |
-
contexts = list(
|
| 686 |
-
contexts
|
| 687 |
-
) # for Qwen2-VL, processor is unhappy accepting a tuple of strings instead of a list.
|
| 688 |
-
# TODO: could we upstream this workaround to HF?
|
| 689 |
-
### this part onward: same as HFLM ###
|
| 690 |
-
|
| 691 |
-
# we assume all gen kwargs in the batch are the same
|
| 692 |
-
# this is safe to assume because the `grouper` object ensures it.
|
| 693 |
-
gen_kwargs = all_gen_kwargs[0]
|
| 694 |
-
# unpack our keyword arguments.
|
| 695 |
-
if isinstance(gen_kwargs, dict):
|
| 696 |
-
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
|
| 697 |
-
# add EOS token to stop sequences
|
| 698 |
-
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
|
| 699 |
-
else:
|
| 700 |
-
raise ValueError(
|
| 701 |
-
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
|
| 702 |
-
)
|
| 703 |
-
if "max_gen_toks" in kwargs.keys():
|
| 704 |
-
max_gen_toks = kwargs.pop("max_gen_toks")
|
| 705 |
-
else:
|
| 706 |
-
max_gen_toks = self.max_gen_toks
|
| 707 |
-
|
| 708 |
-
### end stuff that's entirely copied verbatim from HFLM ###
|
| 709 |
-
|
| 710 |
-
max_ctx_len = self.max_length - max_gen_toks
|
| 711 |
-
|
| 712 |
-
inputs = self.tok_batch_multimodal_encode(
|
| 713 |
-
contexts,
|
| 714 |
-
visuals,
|
| 715 |
-
left_truncate_len=max_ctx_len,
|
| 716 |
-
truncation=self.truncation,
|
| 717 |
-
)
|
| 718 |
-
|
| 719 |
-
context_enc = inputs["input_ids"]
|
| 720 |
-
|
| 721 |
-
if "max_length" not in kwargs:
|
| 722 |
-
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
|
| 723 |
-
|
| 724 |
-
cont = self._model_multimodal_generate(inputs, stop=until, **kwargs)
|
| 725 |
-
|
| 726 |
-
del inputs
|
| 727 |
-
torch.cuda.empty_cache()
|
| 728 |
-
import gc
|
| 729 |
-
|
| 730 |
-
gc.collect()
|
| 731 |
-
|
| 732 |
-
### essentially same as HFLM beyond this line!
|
| 733 |
-
|
| 734 |
-
cont_toks_list = cont.tolist()
|
| 735 |
-
for cont_toks, context in zip(cont_toks_list, contexts):
|
| 736 |
-
# discard context + left-padding toks if using causal decoder-only VLM
|
| 737 |
-
cont_toks = cont_toks[context_enc.shape[1] :]
|
| 738 |
-
|
| 739 |
-
s = self.tok_decode(cont_toks)
|
| 740 |
-
|
| 741 |
-
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
|
| 742 |
-
for term in until:
|
| 743 |
-
if len(term) > 0:
|
| 744 |
-
# ignore '' separator,
|
| 745 |
-
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
|
| 746 |
-
s = s.split(term)[0]
|
| 747 |
-
|
| 748 |
-
res.append(s)
|
| 749 |
-
self.cache_hook.add_partial(
|
| 750 |
-
"generate_until", (context, gen_kwargs), s
|
| 751 |
-
) # TODO: cache key for multimodal input should be what?
|
| 752 |
-
pbar.update(1)
|
| 753 |
-
# reorder this group of results back to original unsorted form
|
| 754 |
-
res = re_ords.get_original(res)
|
| 755 |
-
|
| 756 |
-
pbar.close()
|
| 757 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/huggingface.py
DELETED
|
@@ -1,1480 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import logging
|
| 3 |
-
import os
|
| 4 |
-
from datetime import timedelta
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
| 7 |
-
|
| 8 |
-
import jinja2
|
| 9 |
-
import torch
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
import transformers
|
| 12 |
-
from accelerate import (
|
| 13 |
-
Accelerator,
|
| 14 |
-
InitProcessGroupKwargs,
|
| 15 |
-
find_executable_batch_size,
|
| 16 |
-
)
|
| 17 |
-
from accelerate.utils import get_max_memory
|
| 18 |
-
from huggingface_hub import HfApi
|
| 19 |
-
from packaging import version
|
| 20 |
-
from peft import PeftModel
|
| 21 |
-
from peft import __version__ as PEFT_VERSION
|
| 22 |
-
from tqdm import tqdm
|
| 23 |
-
from transformers.models.auto.modeling_auto import (
|
| 24 |
-
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
| 25 |
-
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
from lm_eval import utils
|
| 29 |
-
from lm_eval.api.instance import Instance
|
| 30 |
-
from lm_eval.api.model import TemplateLM
|
| 31 |
-
from lm_eval.api.registry import register_model
|
| 32 |
-
from lm_eval.models.utils import (
|
| 33 |
-
Collator,
|
| 34 |
-
clear_torch_cache,
|
| 35 |
-
configure_pad_token,
|
| 36 |
-
get_dtype,
|
| 37 |
-
handle_stop_sequences,
|
| 38 |
-
pad_and_concat,
|
| 39 |
-
stop_sequences_criteria,
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
eval_logger = logging.getLogger(__name__)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
@register_model("hf-auto", "hf", "huggingface")
|
| 47 |
-
class HFLM(TemplateLM):
|
| 48 |
-
"""
|
| 49 |
-
An abstracted Huggingface model class. Enables usage with both models of
|
| 50 |
-
`transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
|
| 51 |
-
|
| 52 |
-
Supports data-parallel multi-GPU with HF Accelerate.
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
AUTO_MODEL_CLASS = None
|
| 56 |
-
_DEFAULT_MAX_LENGTH = 2048
|
| 57 |
-
|
| 58 |
-
def __init__(
|
| 59 |
-
self,
|
| 60 |
-
pretrained: Union[str, transformers.PreTrainedModel],
|
| 61 |
-
backend: Literal["default", "causal", "seq2seq"] = "default",
|
| 62 |
-
# override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
|
| 63 |
-
revision: Optional[str] = "main",
|
| 64 |
-
subfolder: Optional[str] = None,
|
| 65 |
-
tokenizer: Optional[
|
| 66 |
-
Union[
|
| 67 |
-
str,
|
| 68 |
-
transformers.PreTrainedTokenizer,
|
| 69 |
-
transformers.PreTrainedTokenizerFast,
|
| 70 |
-
]
|
| 71 |
-
] = None,
|
| 72 |
-
truncation: Optional[bool] = False,
|
| 73 |
-
logits_cache: bool = True,
|
| 74 |
-
max_length: Optional[int] = None,
|
| 75 |
-
device: Optional[str] = "cuda",
|
| 76 |
-
dtype: Optional[Union[str, torch.dtype]] = "auto",
|
| 77 |
-
softmax_dtype: Optional[Union[str, torch.dtype]] = None,
|
| 78 |
-
batch_size: Optional[Union[int, str]] = 1,
|
| 79 |
-
max_batch_size: Optional[int] = 64,
|
| 80 |
-
trust_remote_code: Optional[bool] = False,
|
| 81 |
-
use_fast_tokenizer: Optional[bool] = True,
|
| 82 |
-
add_bos_token: Optional[bool] = False,
|
| 83 |
-
prefix_token_id: Optional[int] = None,
|
| 84 |
-
# arguments used for splitting a model across GPUs naively.
|
| 85 |
-
# only used if `parallelize=True`.
|
| 86 |
-
parallelize: Optional[bool] = False,
|
| 87 |
-
max_memory_per_gpu: Optional[Union[int, str]] = None,
|
| 88 |
-
max_cpu_memory: Optional[Union[int, str]] = None,
|
| 89 |
-
offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
|
| 90 |
-
# PEFT, delta weights and quantization options
|
| 91 |
-
peft: Optional[str] = None,
|
| 92 |
-
delta: Optional[str] = None,
|
| 93 |
-
autogptq: Optional[Union[bool, str]] = False,
|
| 94 |
-
gptqmodel: Optional[bool] = False,
|
| 95 |
-
gguf_file: Optional[str] = None,
|
| 96 |
-
**kwargs,
|
| 97 |
-
) -> None:
|
| 98 |
-
super().__init__()
|
| 99 |
-
# optionally: take in an already-initialized transformers.PreTrainedModel
|
| 100 |
-
if not isinstance(pretrained, str):
|
| 101 |
-
eval_logger.warning(
|
| 102 |
-
"`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
|
| 103 |
-
)
|
| 104 |
-
assert not parallelize, (
|
| 105 |
-
"`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
|
| 106 |
-
)
|
| 107 |
-
self._model = pretrained
|
| 108 |
-
self._device = self._model.device
|
| 109 |
-
self._config = self._model.config
|
| 110 |
-
gpus = 0
|
| 111 |
-
|
| 112 |
-
else:
|
| 113 |
-
assert isinstance(device, str)
|
| 114 |
-
assert isinstance(pretrained, str)
|
| 115 |
-
assert isinstance(batch_size, (int, str))
|
| 116 |
-
|
| 117 |
-
gpus = torch.cuda.device_count()
|
| 118 |
-
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
|
| 119 |
-
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
|
| 120 |
-
if accelerator.num_processes > 1:
|
| 121 |
-
self.accelerator = accelerator
|
| 122 |
-
|
| 123 |
-
if "npu" in accelerator.device.type:
|
| 124 |
-
gpus = torch.npu.device_count()
|
| 125 |
-
|
| 126 |
-
# using one process with no model parallelism
|
| 127 |
-
if not (parallelize or accelerator.num_processes > 1):
|
| 128 |
-
# use user-passed device
|
| 129 |
-
device_list = set(
|
| 130 |
-
["cuda", "cpu"]
|
| 131 |
-
+ [f"cuda:{i}" for i in range(gpus)]
|
| 132 |
-
+ ["mps", "mps:0"]
|
| 133 |
-
+ [f"npu:{i}" for i in range(gpus)]
|
| 134 |
-
)
|
| 135 |
-
if device and device in device_list:
|
| 136 |
-
self._device = torch.device(device)
|
| 137 |
-
eval_logger.info(f"Using device '{device}'")
|
| 138 |
-
if device in ("mps", "mps:0") and version.parse(
|
| 139 |
-
torch.__version__
|
| 140 |
-
) < version.parse("2.1"):
|
| 141 |
-
raise RuntimeError(
|
| 142 |
-
f"mps requires torch >= 2.1. You have {torch.__version__}"
|
| 143 |
-
)
|
| 144 |
-
else:
|
| 145 |
-
eval_logger.info("Device not specified")
|
| 146 |
-
eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
|
| 147 |
-
self._device = (
|
| 148 |
-
torch.device("cuda")
|
| 149 |
-
if torch.cuda.is_available()
|
| 150 |
-
else torch.device("cpu")
|
| 151 |
-
)
|
| 152 |
-
else: # Parallelism managed by accelerate
|
| 153 |
-
if device != "cuda":
|
| 154 |
-
eval_logger.info(
|
| 155 |
-
f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
|
| 156 |
-
)
|
| 157 |
-
# TODO: include in warning that `load_in_8bit` etc. affect this too
|
| 158 |
-
self._device = (
|
| 159 |
-
self.accelerator.device
|
| 160 |
-
if hasattr(self, "accelerator")
|
| 161 |
-
else torch.device(device)
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
revision = str(revision) # cast to string if not already one
|
| 165 |
-
# TODO: update this to be less of a hack once subfolder is fixed in HF
|
| 166 |
-
revision = revision + ("/" + subfolder if subfolder is not None else "")
|
| 167 |
-
|
| 168 |
-
self._get_config(
|
| 169 |
-
pretrained,
|
| 170 |
-
revision=revision,
|
| 171 |
-
trust_remote_code=trust_remote_code,
|
| 172 |
-
gguf_file=gguf_file,
|
| 173 |
-
)
|
| 174 |
-
|
| 175 |
-
# determine which of 'causal' and 'seq2seq' backends to use for HF models
|
| 176 |
-
self._get_backend(
|
| 177 |
-
config=self.config, backend=backend, trust_remote_code=trust_remote_code
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
# load tokenizer so we know tokenizer vocabulary size before loading model and PEFT
|
| 181 |
-
self._create_tokenizer(
|
| 182 |
-
pretrained,
|
| 183 |
-
tokenizer,
|
| 184 |
-
revision=revision,
|
| 185 |
-
trust_remote_code=trust_remote_code,
|
| 186 |
-
use_fast_tokenizer=use_fast_tokenizer,
|
| 187 |
-
gguf_file=gguf_file,
|
| 188 |
-
add_bos_token=add_bos_token,
|
| 189 |
-
)
|
| 190 |
-
|
| 191 |
-
# if we passed `pretrained` as a string, initialize our model now
|
| 192 |
-
if isinstance(pretrained, str):
|
| 193 |
-
self._create_model(
|
| 194 |
-
pretrained=pretrained,
|
| 195 |
-
revision=revision,
|
| 196 |
-
dtype=dtype,
|
| 197 |
-
trust_remote_code=trust_remote_code,
|
| 198 |
-
parallelize=parallelize,
|
| 199 |
-
gpus=gpus,
|
| 200 |
-
max_memory_per_gpu=max_memory_per_gpu,
|
| 201 |
-
max_cpu_memory=max_cpu_memory,
|
| 202 |
-
offload_folder=offload_folder,
|
| 203 |
-
peft=peft,
|
| 204 |
-
delta=delta,
|
| 205 |
-
autogptq=autogptq,
|
| 206 |
-
gptqmodel=gptqmodel,
|
| 207 |
-
gguf_file=gguf_file,
|
| 208 |
-
quantization_config=getattr(self.config, "quantization_config", None),
|
| 209 |
-
**kwargs,
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
# access self._model through self.model property outside this method
|
| 213 |
-
if isinstance(self.model, torch.nn.Module):
|
| 214 |
-
self.model.eval()
|
| 215 |
-
self.model.tie_weights()
|
| 216 |
-
|
| 217 |
-
self.truncation = truncation
|
| 218 |
-
self.logits_cache = logits_cache
|
| 219 |
-
self.vocab_size = self.tokenizer.vocab_size
|
| 220 |
-
# select (or create) a pad token to use
|
| 221 |
-
self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config)
|
| 222 |
-
|
| 223 |
-
self.add_bos_token = add_bos_token
|
| 224 |
-
if "gemma" in getattr(self.config, "model_type", ""):
|
| 225 |
-
self.add_bos_token = True
|
| 226 |
-
eval_logger.info(
|
| 227 |
-
f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it."
|
| 228 |
-
)
|
| 229 |
-
|
| 230 |
-
self._max_length = max_length
|
| 231 |
-
self.pretrained = pretrained
|
| 232 |
-
self.delta = delta
|
| 233 |
-
self.peft = peft
|
| 234 |
-
self.revision = revision
|
| 235 |
-
self.batch_schedule = 1
|
| 236 |
-
self.batch_sizes = {}
|
| 237 |
-
self.max_batch_size = max_batch_size
|
| 238 |
-
self.softmax_dtype = (
|
| 239 |
-
get_dtype(softmax_dtype) if softmax_dtype is not None else None
|
| 240 |
-
)
|
| 241 |
-
|
| 242 |
-
if str(batch_size).startswith("auto"):
|
| 243 |
-
batch_size = batch_size.split(":")
|
| 244 |
-
self.batch_size_per_gpu = batch_size[0]
|
| 245 |
-
self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1
|
| 246 |
-
else:
|
| 247 |
-
self.batch_size_per_gpu = int(batch_size)
|
| 248 |
-
|
| 249 |
-
if isinstance(pretrained, str):
|
| 250 |
-
if gpus >= 1 or str(self.device) == "mps":
|
| 251 |
-
# TODO: can remove this whole snippet except in the mps case, perhaps?
|
| 252 |
-
if not (parallelize or autogptq or hasattr(self, "accelerator")):
|
| 253 |
-
# place model onto device requested manually,
|
| 254 |
-
# if not using HF Accelerate or device_map
|
| 255 |
-
# or any other option that preloads model onto device
|
| 256 |
-
try:
|
| 257 |
-
self.model.to(self.device)
|
| 258 |
-
except ValueError:
|
| 259 |
-
eval_logger.debug(
|
| 260 |
-
"Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
|
| 261 |
-
)
|
| 262 |
-
# multigpu data-parallel support when launched with accelerate
|
| 263 |
-
if gpus > 1:
|
| 264 |
-
if accelerator.num_processes > 1:
|
| 265 |
-
if parallelize:
|
| 266 |
-
eval_logger.warning(
|
| 267 |
-
"You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
|
| 268 |
-
)
|
| 269 |
-
elif gpus > accelerator.num_processes:
|
| 270 |
-
eval_logger.warning(
|
| 271 |
-
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
|
| 272 |
-
"If you would like to use data parallelism, please launch the script "
|
| 273 |
-
"with 'accelerate launch *script*'. "
|
| 274 |
-
f"Current run will proceed with {accelerator.num_processes} devices."
|
| 275 |
-
)
|
| 276 |
-
if self.accelerator.is_local_main_process:
|
| 277 |
-
eval_logger.info(
|
| 278 |
-
f"Using {gpus} devices with data parallelism"
|
| 279 |
-
)
|
| 280 |
-
|
| 281 |
-
self._device = torch.device(f"{accelerator.device}")
|
| 282 |
-
self.accelerator = accelerator
|
| 283 |
-
|
| 284 |
-
self._rank = self.accelerator.local_process_index
|
| 285 |
-
self._world_size = self.accelerator.num_processes
|
| 286 |
-
else:
|
| 287 |
-
# if we aren't launching via accelerate, ditch
|
| 288 |
-
self._rank = 0
|
| 289 |
-
self._world_size = 1
|
| 290 |
-
else:
|
| 291 |
-
# if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
|
| 292 |
-
eval_logger.warning(
|
| 293 |
-
"Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
|
| 294 |
-
)
|
| 295 |
-
self._rank = 0
|
| 296 |
-
self._world_size = 1
|
| 297 |
-
|
| 298 |
-
self.custom_prefix_token_id = prefix_token_id
|
| 299 |
-
if prefix_token_id is not None:
|
| 300 |
-
eval_logger.info(
|
| 301 |
-
f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}"
|
| 302 |
-
)
|
| 303 |
-
|
| 304 |
-
def _get_accelerate_args(
|
| 305 |
-
self,
|
| 306 |
-
parallelize: Optional[bool] = None,
|
| 307 |
-
device_map: Optional[str] = "auto",
|
| 308 |
-
max_memory_per_gpu: Optional[Union[int, str]] = None,
|
| 309 |
-
max_cpu_memory: Optional[Union[int, str]] = None,
|
| 310 |
-
offload_folder: Optional[str] = "./offload",
|
| 311 |
-
gpus: Optional[int] = None,
|
| 312 |
-
) -> dict:
|
| 313 |
-
"""Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
|
| 314 |
-
num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
|
| 315 |
-
num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes
|
| 316 |
-
if (
|
| 317 |
-
num_machines == 0
|
| 318 |
-
and hasattr(self, "accelerator")
|
| 319 |
-
and self.accelerator is not None
|
| 320 |
-
):
|
| 321 |
-
eval_logger.info(
|
| 322 |
-
"We are not in a distributed setting for accelerate. Setting model_parallel to False."
|
| 323 |
-
)
|
| 324 |
-
parallelize = False
|
| 325 |
-
|
| 326 |
-
if parallelize is None:
|
| 327 |
-
# If parallelism is unset by the user, we automatically assign model parallelism
|
| 328 |
-
# if enough extra GPUs are available
|
| 329 |
-
max_memory_all_gpus = get_max_memory()
|
| 330 |
-
# We just want gpu, not cpu, max memory
|
| 331 |
-
if "cpu" in max_memory_all_gpus:
|
| 332 |
-
del max_memory_all_gpus["cpu"]
|
| 333 |
-
parallelize = bool(num_local_processes < len(max_memory_all_gpus))
|
| 334 |
-
eval_logger.info(
|
| 335 |
-
f"Setting model parallel to {parallelize} since "
|
| 336 |
-
f"the number of local processes is {num_local_processes} "
|
| 337 |
-
f"and the number of GPUs is {len(max_memory_all_gpus)}"
|
| 338 |
-
)
|
| 339 |
-
|
| 340 |
-
args = {}
|
| 341 |
-
if parallelize: # Model parallelism will be used
|
| 342 |
-
max_memory = {}
|
| 343 |
-
if max_memory_per_gpu is not None: # Using the provided memory requirements
|
| 344 |
-
max_memory_per_gpu_map = {
|
| 345 |
-
device_idx: max_memory_per_gpu for device_idx in range(gpus)
|
| 346 |
-
}
|
| 347 |
-
else: # Estimating the possible memory requirements
|
| 348 |
-
max_memory_all_gpus = get_max_memory()
|
| 349 |
-
if "cpu" in max_memory_all_gpus:
|
| 350 |
-
del max_memory_all_gpus["cpu"]
|
| 351 |
-
if not hasattr(self, "accelerator"):
|
| 352 |
-
max_memory_per_gpu_map = {
|
| 353 |
-
k: v for k, v in max_memory_all_gpus.items()
|
| 354 |
-
}
|
| 355 |
-
else:
|
| 356 |
-
# use only 1 / num_processes of the GPUs if we are running under accelerate launch
|
| 357 |
-
max_memory_per_gpu_map = {
|
| 358 |
-
k: v
|
| 359 |
-
for k, v in max_memory_all_gpus.items()
|
| 360 |
-
if k % num_local_processes
|
| 361 |
-
== (self.accelerator.process_index % num_local_processes)
|
| 362 |
-
}
|
| 363 |
-
args["max_memory"] = max_memory_per_gpu_map
|
| 364 |
-
args["device_map"] = "auto" if device_map is None else device_map
|
| 365 |
-
eval_logger.info(
|
| 366 |
-
f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to {args.get('device_map')}"
|
| 367 |
-
)
|
| 368 |
-
|
| 369 |
-
if max_cpu_memory is not None:
|
| 370 |
-
max_memory["cpu"] = max_cpu_memory
|
| 371 |
-
|
| 372 |
-
args["offload_folder"] = offload_folder
|
| 373 |
-
elif (
|
| 374 |
-
device_map is None
|
| 375 |
-
): # No model parallelism, we use the default provided device for our model
|
| 376 |
-
if hasattr(self, "accelerator"):
|
| 377 |
-
device_map = {"": f"{self.accelerator.device}"}
|
| 378 |
-
else:
|
| 379 |
-
device_map = {"": str(self.device)}
|
| 380 |
-
args["max_memory"] = None
|
| 381 |
-
args["device_map"] = device_map
|
| 382 |
-
eval_logger.info(
|
| 383 |
-
f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}"
|
| 384 |
-
)
|
| 385 |
-
else:
|
| 386 |
-
args["max_memory"] = None
|
| 387 |
-
args["device_map"] = None
|
| 388 |
-
eval_logger.info("Model parallel was set to False.")
|
| 389 |
-
|
| 390 |
-
return args
|
| 391 |
-
|
| 392 |
-
@property
|
| 393 |
-
def config(self):
|
| 394 |
-
# return the associated transformers.AutoConfig for the given pretrained model.
|
| 395 |
-
return self._config
|
| 396 |
-
|
| 397 |
-
@property
|
| 398 |
-
def model(self):
|
| 399 |
-
# returns the model, unwrapping it if using Accelerate
|
| 400 |
-
if hasattr(self, "accelerator"):
|
| 401 |
-
return self.accelerator.unwrap_model(self._model)
|
| 402 |
-
else:
|
| 403 |
-
return self._model
|
| 404 |
-
|
| 405 |
-
@property
|
| 406 |
-
def eot_token_id(self):
|
| 407 |
-
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
|
| 408 |
-
return self.tokenizer.eos_token_id
|
| 409 |
-
|
| 410 |
-
@property
|
| 411 |
-
def prefix_token_id(self):
|
| 412 |
-
# it is used as prefix for loglikelihood
|
| 413 |
-
if self.custom_prefix_token_id is not None:
|
| 414 |
-
return self.custom_prefix_token_id
|
| 415 |
-
if self.tokenizer.bos_token_id is not None:
|
| 416 |
-
return self.tokenizer.bos_token_id
|
| 417 |
-
return self.tokenizer.eos_token_id
|
| 418 |
-
|
| 419 |
-
@property
|
| 420 |
-
def max_length(self):
|
| 421 |
-
if self._max_length: # if max length manually set, return it
|
| 422 |
-
return self._max_length
|
| 423 |
-
seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
|
| 424 |
-
for attr in seqlen_config_attrs:
|
| 425 |
-
if hasattr(self.model.config, attr):
|
| 426 |
-
return getattr(self.model.config, attr)
|
| 427 |
-
if hasattr(self.tokenizer, "model_max_length"):
|
| 428 |
-
if self.tokenizer.model_max_length == 1000000000000000019884624838656:
|
| 429 |
-
return self._DEFAULT_MAX_LENGTH
|
| 430 |
-
return self.tokenizer.model_max_length
|
| 431 |
-
return self._DEFAULT_MAX_LENGTH
|
| 432 |
-
|
| 433 |
-
@property
|
| 434 |
-
def max_gen_toks(self) -> int:
|
| 435 |
-
return 256
|
| 436 |
-
|
| 437 |
-
@property
|
| 438 |
-
def batch_size(self):
|
| 439 |
-
return self.batch_size_per_gpu
|
| 440 |
-
|
| 441 |
-
@property
|
| 442 |
-
def device(self):
|
| 443 |
-
return self._device
|
| 444 |
-
|
| 445 |
-
@property
|
| 446 |
-
def rank(self):
|
| 447 |
-
return self._rank
|
| 448 |
-
|
| 449 |
-
@property
|
| 450 |
-
def world_size(self):
|
| 451 |
-
return self._world_size
|
| 452 |
-
|
| 453 |
-
@property
|
| 454 |
-
def tokenizer_name(self) -> str:
|
| 455 |
-
return self.tokenizer.name_or_path.replace("/", "__")
|
| 456 |
-
|
| 457 |
-
def _get_backend(
|
| 458 |
-
self,
|
| 459 |
-
config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
|
| 460 |
-
backend: Literal["default", "causal", "seq2seq"] = "default",
|
| 461 |
-
trust_remote_code: Optional[bool] = False,
|
| 462 |
-
) -> None:
|
| 463 |
-
"""
|
| 464 |
-
Helper method during initialization.
|
| 465 |
-
Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
|
| 466 |
-
sets `self.AUTO_MODEL_CLASS` appropriately if not already set.
|
| 467 |
-
|
| 468 |
-
**If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
|
| 469 |
-
user must set `self.backend` to be either "causal" or "seq2seq" manually!**
|
| 470 |
-
"""
|
| 471 |
-
|
| 472 |
-
assert backend in ["default", "causal", "seq2seq"]
|
| 473 |
-
|
| 474 |
-
if backend != "default":
|
| 475 |
-
# if we've settled on non-default backend, use that manually
|
| 476 |
-
if backend == "causal":
|
| 477 |
-
self.backend = backend
|
| 478 |
-
elif backend == "seq2seq":
|
| 479 |
-
self.backend = backend
|
| 480 |
-
eval_logger.info(
|
| 481 |
-
f"Overrode HF model backend type, and using type '{self.backend}'"
|
| 482 |
-
)
|
| 483 |
-
else:
|
| 484 |
-
# determine and use the default HF backend for this model, based on its config + metadata.
|
| 485 |
-
if (
|
| 486 |
-
getattr(config, "model_type")
|
| 487 |
-
in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
|
| 488 |
-
):
|
| 489 |
-
# first check if model type is listed under seq2seq models, since some
|
| 490 |
-
# models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
|
| 491 |
-
# these special cases should be treated as seq2seq models.
|
| 492 |
-
self.backend = "seq2seq"
|
| 493 |
-
eval_logger.debug(f"Using model type '{self.backend}'")
|
| 494 |
-
elif (
|
| 495 |
-
getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 496 |
-
):
|
| 497 |
-
self.backend = "causal"
|
| 498 |
-
eval_logger.debug(f"Using model type '{self.backend}'")
|
| 499 |
-
else:
|
| 500 |
-
if not trust_remote_code:
|
| 501 |
-
eval_logger.warning(
|
| 502 |
-
"HF model type is neither marked as CausalLM or Seq2SeqLM. \
|
| 503 |
-
This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
|
| 504 |
-
"Setting backend to causal"
|
| 505 |
-
)
|
| 506 |
-
# if model type is neither in HF transformers causal or seq2seq model registries
|
| 507 |
-
# then we default to assuming AutoModelForCausalLM
|
| 508 |
-
self.backend = "causal"
|
| 509 |
-
eval_logger.info(
|
| 510 |
-
f"Model type cannot be determined. Using default model type '{self.backend}'"
|
| 511 |
-
)
|
| 512 |
-
|
| 513 |
-
if self.AUTO_MODEL_CLASS is None:
|
| 514 |
-
if self.backend == "causal":
|
| 515 |
-
self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
|
| 516 |
-
elif self.backend == "seq2seq":
|
| 517 |
-
self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
|
| 518 |
-
|
| 519 |
-
def _get_config(
|
| 520 |
-
self,
|
| 521 |
-
pretrained: str,
|
| 522 |
-
revision: str = "main",
|
| 523 |
-
trust_remote_code: bool = False,
|
| 524 |
-
gguf_file: Optional[str] = None,
|
| 525 |
-
) -> None:
|
| 526 |
-
"""Return the model config for HuggingFace models"""
|
| 527 |
-
self._config = transformers.AutoConfig.from_pretrained(
|
| 528 |
-
pretrained,
|
| 529 |
-
revision=revision,
|
| 530 |
-
trust_remote_code=trust_remote_code,
|
| 531 |
-
gguf_file=gguf_file,
|
| 532 |
-
)
|
| 533 |
-
|
| 534 |
-
def _create_model(
|
| 535 |
-
self,
|
| 536 |
-
pretrained: str,
|
| 537 |
-
revision: Optional[str] = "main",
|
| 538 |
-
dtype: Optional[Union[str, torch.dtype]] = "auto",
|
| 539 |
-
trust_remote_code: Optional[bool] = False,
|
| 540 |
-
# arguments used for splitting a model across GPUs naively.
|
| 541 |
-
# only used if `parallelize=True`.
|
| 542 |
-
# (accelerate naive PP (device_map) options)
|
| 543 |
-
parallelize: Optional[bool] = False,
|
| 544 |
-
gpus: Optional[int] = None,
|
| 545 |
-
max_memory_per_gpu: Optional[Union[int, str]] = None,
|
| 546 |
-
max_cpu_memory: Optional[Union[int, str]] = None,
|
| 547 |
-
offload_folder: Optional[str] = "./offload",
|
| 548 |
-
# PEFT, delta weights and quantization options
|
| 549 |
-
peft: Optional[str] = None,
|
| 550 |
-
delta: Optional[str] = None,
|
| 551 |
-
autogptq: Optional[Union[bool, str]] = False,
|
| 552 |
-
gptqmodel: Optional[bool] = False,
|
| 553 |
-
gguf_file: Optional[str] = None,
|
| 554 |
-
quantization_config: Optional[Dict[str, Any]] = None,
|
| 555 |
-
**kwargs,
|
| 556 |
-
) -> None:
|
| 557 |
-
"""
|
| 558 |
-
Initializes an HF or HF-compatible PreTrainedModel from scratch
|
| 559 |
-
inside HFLM, using the kwargs passed into self.__init__().
|
| 560 |
-
|
| 561 |
-
Also handles functionality such as AutoGPTQ usage and PEFT wrapping.
|
| 562 |
-
|
| 563 |
-
For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
|
| 564 |
-
(such as PyTorch models that are nearly, but not quite, fully mirroring
|
| 565 |
-
HF's public interface relied on in this HFLM class)
|
| 566 |
-
please consider subclassing HFLM and overriding this and other methods as needed.
|
| 567 |
-
"""
|
| 568 |
-
|
| 569 |
-
model_kwargs = kwargs if kwargs else {}
|
| 570 |
-
|
| 571 |
-
model_kwargs.update(
|
| 572 |
-
self._get_accelerate_args(
|
| 573 |
-
parallelize=parallelize,
|
| 574 |
-
device_map=kwargs.get("device_map", None),
|
| 575 |
-
max_memory_per_gpu=max_memory_per_gpu,
|
| 576 |
-
max_cpu_memory=max_cpu_memory,
|
| 577 |
-
offload_folder=offload_folder,
|
| 578 |
-
gpus=gpus,
|
| 579 |
-
)
|
| 580 |
-
)
|
| 581 |
-
|
| 582 |
-
if not autogptq and not gptqmodel:
|
| 583 |
-
if model_kwargs.get("load_in_4bit", None):
|
| 584 |
-
assert transformers.__version__ >= "4.30.0", (
|
| 585 |
-
"load_in_4bit requires transformers >= 4.30.0"
|
| 586 |
-
)
|
| 587 |
-
if transformers.__version__ >= "4.30.0":
|
| 588 |
-
if model_kwargs.get("load_in_4bit", None):
|
| 589 |
-
if model_kwargs.get("bnb_4bit_compute_dtype", None):
|
| 590 |
-
model_kwargs["bnb_4bit_compute_dtype"] = get_dtype(
|
| 591 |
-
model_kwargs["bnb_4bit_compute_dtype"]
|
| 592 |
-
)
|
| 593 |
-
|
| 594 |
-
self._model = self.AUTO_MODEL_CLASS.from_pretrained(
|
| 595 |
-
pretrained,
|
| 596 |
-
revision=revision,
|
| 597 |
-
torch_dtype=get_dtype(dtype),
|
| 598 |
-
trust_remote_code=trust_remote_code,
|
| 599 |
-
gguf_file=gguf_file,
|
| 600 |
-
quantization_config=quantization_config,
|
| 601 |
-
**model_kwargs,
|
| 602 |
-
)
|
| 603 |
-
else:
|
| 604 |
-
if autogptq and gptqmodel:
|
| 605 |
-
raise ValueError(
|
| 606 |
-
"Cannot use both 'autogptq' and 'gptqmodel' options at the same time."
|
| 607 |
-
)
|
| 608 |
-
|
| 609 |
-
if autogptq:
|
| 610 |
-
try:
|
| 611 |
-
from auto_gptq import AutoGPTQForCausalLM
|
| 612 |
-
except ModuleNotFoundError as exception:
|
| 613 |
-
raise type(exception)(
|
| 614 |
-
"Tried to load auto_gptq, but auto-gptq is not installed ",
|
| 615 |
-
"please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
|
| 616 |
-
)
|
| 617 |
-
|
| 618 |
-
self._model = AutoGPTQForCausalLM.from_quantized(
|
| 619 |
-
pretrained,
|
| 620 |
-
trust_remote_code=trust_remote_code,
|
| 621 |
-
model_basename=None if autogptq is True else Path(autogptq).stem,
|
| 622 |
-
use_safetensors=True
|
| 623 |
-
if autogptq is True
|
| 624 |
-
else autogptq.endswith(".safetensors"),
|
| 625 |
-
**model_kwargs,
|
| 626 |
-
)
|
| 627 |
-
|
| 628 |
-
if gptqmodel:
|
| 629 |
-
try:
|
| 630 |
-
from gptqmodel import GPTQModel
|
| 631 |
-
except ModuleNotFoundError as exception:
|
| 632 |
-
raise type(exception)(
|
| 633 |
-
"Tried to load gptqmodel, but gptqmodel is not installed ",
|
| 634 |
-
"please install gptqmodel via `pip install gptqmodel --no-build-isolation` or `pip install lm-eval[gptqmodel] --no-build-isolation`",
|
| 635 |
-
)
|
| 636 |
-
|
| 637 |
-
self._model = GPTQModel.from_quantized(
|
| 638 |
-
pretrained, trust_remote_code=trust_remote_code, **model_kwargs
|
| 639 |
-
)
|
| 640 |
-
|
| 641 |
-
if peft and delta:
|
| 642 |
-
raise ValueError(
|
| 643 |
-
"Cannot use both 'peft' and 'delta' options at the same time."
|
| 644 |
-
)
|
| 645 |
-
|
| 646 |
-
if peft:
|
| 647 |
-
if model_kwargs.get("load_in_4bit", None):
|
| 648 |
-
if version.parse(PEFT_VERSION) < version.parse("0.4.0"):
|
| 649 |
-
raise AssertionError("load_in_4bit requires peft >= 0.4.0")
|
| 650 |
-
if self._model.config.vocab_size != len(self.tokenizer):
|
| 651 |
-
# resize model for LoRAs with added tokens
|
| 652 |
-
eval_logger.info(
|
| 653 |
-
f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
|
| 654 |
-
)
|
| 655 |
-
self._model.resize_token_embeddings(len(self.tokenizer))
|
| 656 |
-
self._model = PeftModel.from_pretrained(
|
| 657 |
-
self._model, peft, revision=revision
|
| 658 |
-
)
|
| 659 |
-
elif delta:
|
| 660 |
-
if autogptq:
|
| 661 |
-
eval_logger.warning(
|
| 662 |
-
"Delta weights might trigger unexpected behavior when used with AutoGPTQ."
|
| 663 |
-
)
|
| 664 |
-
_model_delta = self.AUTO_MODEL_CLASS.from_pretrained(
|
| 665 |
-
delta,
|
| 666 |
-
revision=revision,
|
| 667 |
-
torch_dtype=get_dtype(dtype),
|
| 668 |
-
trust_remote_code=trust_remote_code,
|
| 669 |
-
**model_kwargs,
|
| 670 |
-
)
|
| 671 |
-
for name, param in self._model.state_dict().items():
|
| 672 |
-
try:
|
| 673 |
-
param.data += _model_delta.state_dict()[name]
|
| 674 |
-
except KeyError:
|
| 675 |
-
raise KeyError(f"Delta model is missing weights for layer: {name}")
|
| 676 |
-
except Exception as e:
|
| 677 |
-
raise RuntimeError(
|
| 678 |
-
f"Failed to add delta weights to layer {name}. Error: {e}"
|
| 679 |
-
)
|
| 680 |
-
|
| 681 |
-
del _model_delta
|
| 682 |
-
|
| 683 |
-
return None
|
| 684 |
-
|
| 685 |
-
def _create_tokenizer(
|
| 686 |
-
self,
|
| 687 |
-
pretrained: Union[str, transformers.PreTrainedModel],
|
| 688 |
-
tokenizer: Optional[
|
| 689 |
-
Union[
|
| 690 |
-
str,
|
| 691 |
-
transformers.PreTrainedTokenizer,
|
| 692 |
-
transformers.PreTrainedTokenizerFast,
|
| 693 |
-
]
|
| 694 |
-
],
|
| 695 |
-
revision: Optional[str] = "main",
|
| 696 |
-
trust_remote_code: Optional[bool] = False,
|
| 697 |
-
use_fast_tokenizer: Optional[bool] = True,
|
| 698 |
-
gguf_file: Optional[str] = None,
|
| 699 |
-
add_bos_token: Optional[bool] = False,
|
| 700 |
-
) -> None:
|
| 701 |
-
"""
|
| 702 |
-
Helper method during initialization.
|
| 703 |
-
|
| 704 |
-
Create a tokenizer object corresponding to the correct
|
| 705 |
-
tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
|
| 706 |
-
"""
|
| 707 |
-
kwargs = {
|
| 708 |
-
"revision": revision,
|
| 709 |
-
"trust_remote_code": trust_remote_code,
|
| 710 |
-
}
|
| 711 |
-
|
| 712 |
-
# gguf format embeds tokenizer and is not compatible with hf tokenizer `use_fast` param
|
| 713 |
-
if gguf_file is not None:
|
| 714 |
-
kwargs["gguf_file"] = gguf_file
|
| 715 |
-
else:
|
| 716 |
-
kwargs["use_fast"] = use_fast_tokenizer
|
| 717 |
-
|
| 718 |
-
if add_bos_token:
|
| 719 |
-
kwargs["add_bos_token"] = True
|
| 720 |
-
|
| 721 |
-
if tokenizer:
|
| 722 |
-
if isinstance(tokenizer, str):
|
| 723 |
-
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 724 |
-
tokenizer, **kwargs
|
| 725 |
-
)
|
| 726 |
-
else:
|
| 727 |
-
assert isinstance(
|
| 728 |
-
tokenizer, transformers.PreTrainedTokenizer
|
| 729 |
-
) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
|
| 730 |
-
self.tokenizer = tokenizer
|
| 731 |
-
else:
|
| 732 |
-
# Get tokenizer based on 'pretrained'
|
| 733 |
-
if isinstance(pretrained, str):
|
| 734 |
-
model_name = pretrained
|
| 735 |
-
else:
|
| 736 |
-
# get the HF hub name via accessor on model
|
| 737 |
-
model_name = self.model.name_or_path
|
| 738 |
-
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 739 |
-
model_name, **kwargs
|
| 740 |
-
)
|
| 741 |
-
return None
|
| 742 |
-
|
| 743 |
-
def _detect_batch_size(self, requests=None, pos: int = 0):
|
| 744 |
-
if requests:
|
| 745 |
-
_, context_enc, continuation_enc = requests[pos]
|
| 746 |
-
max_length = len(
|
| 747 |
-
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
|
| 748 |
-
)
|
| 749 |
-
max_context_enc = len(context_enc[-(self.max_length + 1) :])
|
| 750 |
-
max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
|
| 751 |
-
else:
|
| 752 |
-
max_length = self.max_length
|
| 753 |
-
max_context_enc = max_length
|
| 754 |
-
max_cont_enc = max_length
|
| 755 |
-
|
| 756 |
-
# if OOM, then halves batch_size and tries again
|
| 757 |
-
@find_executable_batch_size(starting_batch_size=self.max_batch_size)
|
| 758 |
-
def forward_batch(batch_size):
|
| 759 |
-
if self.backend == "seq2seq":
|
| 760 |
-
length = max(max_context_enc, max_cont_enc)
|
| 761 |
-
batched_conts = torch.ones(
|
| 762 |
-
(batch_size, length), device=self.device
|
| 763 |
-
).long()
|
| 764 |
-
test_batch = torch.ones((batch_size, length), device=self.device).long()
|
| 765 |
-
call_kwargs = {
|
| 766 |
-
"attn_mask": test_batch,
|
| 767 |
-
"labels": batched_conts,
|
| 768 |
-
}
|
| 769 |
-
else:
|
| 770 |
-
call_kwargs = {}
|
| 771 |
-
test_batch = torch.ones(
|
| 772 |
-
(batch_size, max_length), device=self.device
|
| 773 |
-
).long()
|
| 774 |
-
for _ in range(5):
|
| 775 |
-
out = F.log_softmax( # noqa: F841
|
| 776 |
-
self._model_call(test_batch, **call_kwargs),
|
| 777 |
-
dim=-1,
|
| 778 |
-
dtype=self.softmax_dtype,
|
| 779 |
-
)
|
| 780 |
-
|
| 781 |
-
return batch_size
|
| 782 |
-
|
| 783 |
-
try:
|
| 784 |
-
batch_size = forward_batch()
|
| 785 |
-
except RuntimeError as e:
|
| 786 |
-
if "No executable batch size found" in str(e):
|
| 787 |
-
batch_size = 1
|
| 788 |
-
else:
|
| 789 |
-
raise
|
| 790 |
-
|
| 791 |
-
if self.world_size > 1:
|
| 792 |
-
# if multi-GPU, always take minimum over all selected batch sizes
|
| 793 |
-
max_rnk_bs = torch.tensor([batch_size], device=self.device)
|
| 794 |
-
gathered = (
|
| 795 |
-
self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
|
| 796 |
-
)
|
| 797 |
-
batch_size = min(gathered)
|
| 798 |
-
clear_torch_cache()
|
| 799 |
-
return batch_size
|
| 800 |
-
|
| 801 |
-
clear_torch_cache()
|
| 802 |
-
return batch_size
|
| 803 |
-
|
| 804 |
-
def tok_encode(
|
| 805 |
-
self, string: str, left_truncate_len=None, add_special_tokens=None
|
| 806 |
-
) -> List[int]:
|
| 807 |
-
""" """
|
| 808 |
-
# default for None - empty dict, use predefined tokenizer param
|
| 809 |
-
# used for all models except for CausalLM or predefined value
|
| 810 |
-
special_tokens_kwargs = {}
|
| 811 |
-
|
| 812 |
-
# by default for CausalLM - false or self.add_bos_token is set
|
| 813 |
-
if add_special_tokens is None:
|
| 814 |
-
if self.backend == "causal":
|
| 815 |
-
special_tokens_kwargs = {
|
| 816 |
-
"add_special_tokens": False or self.add_bos_token
|
| 817 |
-
}
|
| 818 |
-
# otherwise the method explicitly defines the value
|
| 819 |
-
else:
|
| 820 |
-
special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
|
| 821 |
-
|
| 822 |
-
encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
|
| 823 |
-
|
| 824 |
-
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
|
| 825 |
-
if left_truncate_len:
|
| 826 |
-
encoding = encoding[-left_truncate_len:]
|
| 827 |
-
|
| 828 |
-
return encoding
|
| 829 |
-
|
| 830 |
-
def tok_batch_encode(
|
| 831 |
-
self,
|
| 832 |
-
strings: List[str],
|
| 833 |
-
padding_side: str = "left",
|
| 834 |
-
left_truncate_len: int = None,
|
| 835 |
-
truncation: bool = False,
|
| 836 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 837 |
-
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
|
| 838 |
-
old_padding_side = self.tokenizer.padding_side
|
| 839 |
-
self.tokenizer.padding_side = padding_side
|
| 840 |
-
|
| 841 |
-
add_special_tokens = {}
|
| 842 |
-
if self.backend == "causal":
|
| 843 |
-
add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
|
| 844 |
-
|
| 845 |
-
encoding = self.tokenizer(
|
| 846 |
-
strings,
|
| 847 |
-
truncation=truncation,
|
| 848 |
-
padding="longest",
|
| 849 |
-
return_tensors="pt",
|
| 850 |
-
**add_special_tokens,
|
| 851 |
-
)
|
| 852 |
-
if left_truncate_len:
|
| 853 |
-
original_lengths = encoding["input_ids"].size(1)
|
| 854 |
-
if original_lengths > left_truncate_len:
|
| 855 |
-
eval_logger.warn(
|
| 856 |
-
f"Left truncation applied. Original sequence length was {original_lengths}, "
|
| 857 |
-
f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
|
| 858 |
-
)
|
| 859 |
-
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
|
| 860 |
-
encoding["attention_mask"] = encoding["attention_mask"][
|
| 861 |
-
:, -left_truncate_len:
|
| 862 |
-
]
|
| 863 |
-
self.tokenizer.padding_side = old_padding_side
|
| 864 |
-
|
| 865 |
-
return encoding["input_ids"], encoding["attention_mask"]
|
| 866 |
-
|
| 867 |
-
def tok_decode(self, tokens, skip_special_tokens=True):
|
| 868 |
-
return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
|
| 869 |
-
|
| 870 |
-
def _model_call(self, inps, attn_mask=None, labels=None):
|
| 871 |
-
"""
|
| 872 |
-
:param inps: torch.Tensor
|
| 873 |
-
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
|
| 874 |
-
[batch, sequence_ctx]. the size of sequence may vary from call to call
|
| 875 |
-
:param attn_mask: torch.Tensor, optional
|
| 876 |
-
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
|
| 877 |
-
(and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
|
| 878 |
-
:param labels: torch.Tensor, optional
|
| 879 |
-
A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed
|
| 880 |
-
(and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
|
| 881 |
-
:return
|
| 882 |
-
A torch tensor of shape [batch, sequence, vocab] with the
|
| 883 |
-
logits returned from the model's decoder
|
| 884 |
-
"""
|
| 885 |
-
with torch.no_grad():
|
| 886 |
-
if attn_mask is not None or labels is not None:
|
| 887 |
-
assert attn_mask is not None and labels is not None
|
| 888 |
-
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
|
| 889 |
-
return self.model(
|
| 890 |
-
input_ids=inps, attention_mask=attn_mask, labels=labels
|
| 891 |
-
).logits
|
| 892 |
-
else:
|
| 893 |
-
assert self.AUTO_MODEL_CLASS in (
|
| 894 |
-
transformers.AutoModelForCausalLM,
|
| 895 |
-
transformers.AutoModelForVision2Seq,
|
| 896 |
-
)
|
| 897 |
-
return self.model(inps).logits
|
| 898 |
-
|
| 899 |
-
def _model_generate(self, context, max_length, stop, **generation_kwargs):
|
| 900 |
-
# temperature = 0.0 if not set
|
| 901 |
-
# if do_sample is false and temp==0.0:
|
| 902 |
-
# remove temperature, as do_sample=False takes care of this
|
| 903 |
-
# and we don't want a warning from HF
|
| 904 |
-
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
|
| 905 |
-
do_sample = generation_kwargs.get("do_sample", None)
|
| 906 |
-
|
| 907 |
-
# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
|
| 908 |
-
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
|
| 909 |
-
generation_kwargs["do_sample"] = do_sample = False
|
| 910 |
-
|
| 911 |
-
if do_sample is False and generation_kwargs.get("temperature") == 0.0:
|
| 912 |
-
generation_kwargs.pop("temperature")
|
| 913 |
-
# build stopping criteria
|
| 914 |
-
stopping_criteria = stop_sequences_criteria(
|
| 915 |
-
self.tokenizer, stop, context.shape[1], context.shape[0]
|
| 916 |
-
)
|
| 917 |
-
return self.model.generate(
|
| 918 |
-
input_ids=context,
|
| 919 |
-
max_length=max_length,
|
| 920 |
-
stopping_criteria=stopping_criteria,
|
| 921 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
| 922 |
-
use_cache=True,
|
| 923 |
-
**generation_kwargs,
|
| 924 |
-
)
|
| 925 |
-
|
| 926 |
-
def _select_cont_toks(
|
| 927 |
-
self, logits: torch.Tensor, contlen: int = None, inplen: int = None
|
| 928 |
-
) -> torch.Tensor:
|
| 929 |
-
if self.backend == "causal":
|
| 930 |
-
assert contlen and inplen, (
|
| 931 |
-
"Must pass input len and cont. len to select scored logits for causal LM"
|
| 932 |
-
)
|
| 933 |
-
# discard right-padding.
|
| 934 |
-
# also discard the input/context tokens. we'll only score continuations.
|
| 935 |
-
logits = logits[inplen - contlen : inplen]
|
| 936 |
-
elif self.backend == "seq2seq":
|
| 937 |
-
assert contlen and not inplen, (
|
| 938 |
-
"Selecting scored logits for Seq2SeqLM requires only cont. len"
|
| 939 |
-
)
|
| 940 |
-
# only discard right-padding.
|
| 941 |
-
# the logits input to this fn only contain decoder-side tokens.
|
| 942 |
-
logits = logits[:contlen]
|
| 943 |
-
|
| 944 |
-
return logits
|
| 945 |
-
|
| 946 |
-
def loglikelihood_rolling(
|
| 947 |
-
self, requests: List[Instance], disable_tqdm: bool = False
|
| 948 |
-
) -> List[float]:
|
| 949 |
-
adaptive_batch_size = None
|
| 950 |
-
if self.batch_size == "auto":
|
| 951 |
-
# using rolling window with maximum context
|
| 952 |
-
print("Passed argument batch_size = auto. Detecting largest batch size")
|
| 953 |
-
batch_size = self._detect_batch_size()
|
| 954 |
-
print(f"Determined Largest batch size: {batch_size}")
|
| 955 |
-
adaptive_batch_size = batch_size
|
| 956 |
-
|
| 957 |
-
# First, collect all windows from all requests
|
| 958 |
-
all_windows = [] # List of (request_idx, window) tuples
|
| 959 |
-
request_window_counts = [] # Track number of windows per request
|
| 960 |
-
|
| 961 |
-
for req_idx, (string,) in enumerate(
|
| 962 |
-
tqdm(
|
| 963 |
-
[req.args for req in requests],
|
| 964 |
-
disable=(disable_tqdm or (self.rank != 0)),
|
| 965 |
-
)
|
| 966 |
-
):
|
| 967 |
-
rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
|
| 968 |
-
map(
|
| 969 |
-
utils.make_disjoint_window,
|
| 970 |
-
utils.get_rolling_token_windows(
|
| 971 |
-
token_list=self.tok_encode(string),
|
| 972 |
-
prefix_token=self.prefix_token_id,
|
| 973 |
-
max_seq_len=self.max_length,
|
| 974 |
-
context_len=1,
|
| 975 |
-
),
|
| 976 |
-
)
|
| 977 |
-
)
|
| 978 |
-
|
| 979 |
-
# TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
|
| 980 |
-
windows = [(None,) + x for x in rolling_token_windows]
|
| 981 |
-
|
| 982 |
-
# Store windows with their request index
|
| 983 |
-
all_windows.extend((req_idx, window) for window in windows)
|
| 984 |
-
request_window_counts.append(len(windows))
|
| 985 |
-
|
| 986 |
-
# Handle distributed case padding
|
| 987 |
-
pad_amnt = 0
|
| 988 |
-
if self.world_size > 1:
|
| 989 |
-
mytensor = torch.tensor(len(all_windows), device=self.device)
|
| 990 |
-
gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
|
| 991 |
-
pad_amnt = max(gathered) - gathered[self.rank]
|
| 992 |
-
if pad_amnt > 0:
|
| 993 |
-
all_windows += pad_amnt * [all_windows[0]]
|
| 994 |
-
|
| 995 |
-
all_nlls = []
|
| 996 |
-
batch_size = adaptive_batch_size or self.batch_size
|
| 997 |
-
for i in range(0, len(all_windows), batch_size):
|
| 998 |
-
batch = all_windows[i : i + batch_size]
|
| 999 |
-
# Extract just the windows for processing, keeping track of request indices
|
| 1000 |
-
batch_indices, batch_windows = zip(*batch)
|
| 1001 |
-
|
| 1002 |
-
batch_nlls = self._loglikelihood_tokens(
|
| 1003 |
-
requests=batch_windows,
|
| 1004 |
-
disable_tqdm=False,
|
| 1005 |
-
override_bs=len(batch_windows),
|
| 1006 |
-
)
|
| 1007 |
-
# Store results with their request indices
|
| 1008 |
-
all_nlls.extend(zip(batch_indices, batch_nlls))
|
| 1009 |
-
|
| 1010 |
-
# Remove padding if necessary
|
| 1011 |
-
if (self.world_size > 1) and (pad_amnt > 0):
|
| 1012 |
-
all_nlls = all_nlls[:-pad_amnt]
|
| 1013 |
-
|
| 1014 |
-
# Reconstruct per-request loglikelihoods
|
| 1015 |
-
loglikelihoods = []
|
| 1016 |
-
current_idx = 0
|
| 1017 |
-
for window_count in request_window_counts:
|
| 1018 |
-
# Get all nlls for this request
|
| 1019 |
-
request_nlls = all_nlls[current_idx : current_idx + window_count]
|
| 1020 |
-
# Sum up the nlls for this request (discarding is_greedy)
|
| 1021 |
-
request_total = sum(nll[0] for _, nll in request_nlls)
|
| 1022 |
-
loglikelihoods.append(request_total)
|
| 1023 |
-
current_idx += window_count
|
| 1024 |
-
|
| 1025 |
-
string = requests[len(loglikelihoods) - 1].args[0]
|
| 1026 |
-
self.cache_hook.add_partial(
|
| 1027 |
-
"loglikelihood_rolling", (string,), request_total
|
| 1028 |
-
)
|
| 1029 |
-
|
| 1030 |
-
return loglikelihoods
|
| 1031 |
-
|
| 1032 |
-
def _batch_scheduler(self, pos, n_reordered_requests):
|
| 1033 |
-
sched = pos // int(len(n_reordered_requests) / self.batch_schedule)
|
| 1034 |
-
if sched in self.batch_sizes:
|
| 1035 |
-
return self.batch_sizes[sched]
|
| 1036 |
-
if (len(self.batch_sizes) > 1) and (
|
| 1037 |
-
self.batch_sizes[sched - 1] == self.max_batch_size
|
| 1038 |
-
):
|
| 1039 |
-
# if previous batch size is already maximal, skip recomputation
|
| 1040 |
-
self.batch_sizes[sched] = self.max_batch_size
|
| 1041 |
-
return self.batch_sizes[sched]
|
| 1042 |
-
print(
|
| 1043 |
-
f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
|
| 1044 |
-
)
|
| 1045 |
-
self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos)
|
| 1046 |
-
print(f"Determined largest batch size: {self.batch_sizes[sched]}")
|
| 1047 |
-
return self.batch_sizes[sched]
|
| 1048 |
-
|
| 1049 |
-
def _loglikelihood_tokens(
|
| 1050 |
-
self,
|
| 1051 |
-
requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
|
| 1052 |
-
disable_tqdm: bool = False,
|
| 1053 |
-
override_bs: int = None,
|
| 1054 |
-
) -> List[Tuple[float, bool]]:
|
| 1055 |
-
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
|
| 1056 |
-
res = []
|
| 1057 |
-
|
| 1058 |
-
def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
|
| 1059 |
-
"""Defines the key for the sorted method"""
|
| 1060 |
-
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 1061 |
-
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 1062 |
-
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 1063 |
-
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 1064 |
-
# automatic adaptive batches much much easier to implement
|
| 1065 |
-
# - any OOMs will happen right away rather than near the end
|
| 1066 |
-
|
| 1067 |
-
toks = req[1] + req[2]
|
| 1068 |
-
return -len(toks), tuple(toks)
|
| 1069 |
-
|
| 1070 |
-
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
|
| 1071 |
-
"""Defines the key to group and lookup one-token continuations"""
|
| 1072 |
-
# Use with group_by="contexts" (optional)"
|
| 1073 |
-
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
|
| 1074 |
-
# speeds up some multiple-choice tasks proportionally to the number of choices.
|
| 1075 |
-
# groups requests by context+continuation[:-1] and infer on one request/group.
|
| 1076 |
-
return req[-2] + req[-1][:-1]
|
| 1077 |
-
|
| 1078 |
-
re_ord = Collator(
|
| 1079 |
-
requests,
|
| 1080 |
-
sort_fn=_collate,
|
| 1081 |
-
group_by="contexts"
|
| 1082 |
-
if self.backend == "causal" and self.logits_cache
|
| 1083 |
-
else None,
|
| 1084 |
-
group_fn=_lookup_one_token_cont,
|
| 1085 |
-
)
|
| 1086 |
-
|
| 1087 |
-
# automatic (variable) batch size detection for vectorization
|
| 1088 |
-
# pull longest context sample from request
|
| 1089 |
-
n_reordered_requests = len(re_ord)
|
| 1090 |
-
batch_size = (
|
| 1091 |
-
self.batch_size
|
| 1092 |
-
if self.batch_size != "auto"
|
| 1093 |
-
else override_bs
|
| 1094 |
-
if override_bs is not None
|
| 1095 |
-
else 0
|
| 1096 |
-
)
|
| 1097 |
-
batch_fn = (
|
| 1098 |
-
self._batch_scheduler
|
| 1099 |
-
if self.batch_size == "auto"
|
| 1100 |
-
and n_reordered_requests > 0
|
| 1101 |
-
and not override_bs
|
| 1102 |
-
else None
|
| 1103 |
-
)
|
| 1104 |
-
|
| 1105 |
-
chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
|
| 1106 |
-
pbar = tqdm(
|
| 1107 |
-
total=len(requests),
|
| 1108 |
-
disable=(disable_tqdm or (self.rank != 0)),
|
| 1109 |
-
desc="Running loglikelihood requests",
|
| 1110 |
-
)
|
| 1111 |
-
for chunk in chunks:
|
| 1112 |
-
inps = []
|
| 1113 |
-
cont_toks_list = []
|
| 1114 |
-
inplens = []
|
| 1115 |
-
|
| 1116 |
-
conts = []
|
| 1117 |
-
encoder_attns = []
|
| 1118 |
-
|
| 1119 |
-
padding_len_inp = None
|
| 1120 |
-
padding_len_cont = None
|
| 1121 |
-
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
|
| 1122 |
-
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
|
| 1123 |
-
# again because vectorizing is annoying
|
| 1124 |
-
|
| 1125 |
-
for _, context_enc, continuation_enc in chunk:
|
| 1126 |
-
# sanity check
|
| 1127 |
-
assert len(context_enc) > 0
|
| 1128 |
-
assert len(continuation_enc) > 0
|
| 1129 |
-
assert len(continuation_enc) <= self.max_length
|
| 1130 |
-
|
| 1131 |
-
# how this all works (illustrated on a causal decoder-only setup):
|
| 1132 |
-
# CTX CONT
|
| 1133 |
-
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
|
| 1134 |
-
# model \ \
|
| 1135 |
-
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
|
| 1136 |
-
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
|
| 1137 |
-
|
| 1138 |
-
# when too long to fit in context, truncate from the left
|
| 1139 |
-
if self.backend == "causal":
|
| 1140 |
-
total_length = len(context_enc) + len(continuation_enc)
|
| 1141 |
-
if total_length > self.max_length + 1:
|
| 1142 |
-
eval_logger.warning(
|
| 1143 |
-
f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
|
| 1144 |
-
f"exceeds model's maximum length ({self.max_length}). "
|
| 1145 |
-
f"Truncating {total_length - self.max_length + 1} tokens from the left."
|
| 1146 |
-
)
|
| 1147 |
-
inp = torch.tensor(
|
| 1148 |
-
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
|
| 1149 |
-
dtype=torch.long,
|
| 1150 |
-
device=self.device,
|
| 1151 |
-
)
|
| 1152 |
-
(inplen,) = inp.shape
|
| 1153 |
-
elif self.backend == "seq2seq":
|
| 1154 |
-
inp = torch.tensor(
|
| 1155 |
-
(context_enc)[-self.max_length :],
|
| 1156 |
-
dtype=torch.long,
|
| 1157 |
-
device=self.device,
|
| 1158 |
-
)
|
| 1159 |
-
(inplen,) = inp.shape
|
| 1160 |
-
|
| 1161 |
-
# build encoder attn masks
|
| 1162 |
-
encoder_attns.append(torch.ones_like(inp))
|
| 1163 |
-
|
| 1164 |
-
cont = torch.tensor(
|
| 1165 |
-
(continuation_enc)[-self.max_length :],
|
| 1166 |
-
# TODO: left-shift these?
|
| 1167 |
-
# TODO: our code assumes we never end up truncating conts for either model type
|
| 1168 |
-
dtype=torch.long,
|
| 1169 |
-
device=self.device,
|
| 1170 |
-
)
|
| 1171 |
-
(contlen,) = cont.shape
|
| 1172 |
-
|
| 1173 |
-
conts.append(cont)
|
| 1174 |
-
|
| 1175 |
-
padding_len_cont = (
|
| 1176 |
-
max(padding_len_cont, contlen)
|
| 1177 |
-
if padding_len_cont is not None
|
| 1178 |
-
else contlen
|
| 1179 |
-
)
|
| 1180 |
-
|
| 1181 |
-
padding_len_inp = (
|
| 1182 |
-
max(padding_len_inp, inplen)
|
| 1183 |
-
if padding_len_inp is not None
|
| 1184 |
-
else inplen
|
| 1185 |
-
)
|
| 1186 |
-
|
| 1187 |
-
inps.append(inp) # [1, inp_length]
|
| 1188 |
-
cont_toks_list.append(continuation_enc)
|
| 1189 |
-
inplens.append(inplen)
|
| 1190 |
-
|
| 1191 |
-
# create encoder attn mask and batched conts, if seq2seq
|
| 1192 |
-
call_kwargs = {}
|
| 1193 |
-
if self.backend == "causal":
|
| 1194 |
-
batched_inps = pad_and_concat(
|
| 1195 |
-
padding_len_inp, inps, padding_side="right"
|
| 1196 |
-
) # [batch, padding_len_inp]
|
| 1197 |
-
elif self.backend == "seq2seq":
|
| 1198 |
-
# TODO: left-pad encoder inps and mask?
|
| 1199 |
-
batched_inps = pad_and_concat(
|
| 1200 |
-
padding_len_inp, inps
|
| 1201 |
-
) # [batch, padding_len_inp]
|
| 1202 |
-
batched_conts = pad_and_concat(
|
| 1203 |
-
padding_len_cont, conts
|
| 1204 |
-
) # [batch, padding_len_cont]
|
| 1205 |
-
batched_encoder_mask = pad_and_concat(
|
| 1206 |
-
padding_len_inp, encoder_attns
|
| 1207 |
-
) # [batch, padding_len_inp]
|
| 1208 |
-
call_kwargs = {
|
| 1209 |
-
"attn_mask": batched_encoder_mask,
|
| 1210 |
-
"labels": batched_conts,
|
| 1211 |
-
}
|
| 1212 |
-
|
| 1213 |
-
multi_logits = F.log_softmax(
|
| 1214 |
-
self._model_call(batched_inps, **call_kwargs),
|
| 1215 |
-
dim=-1,
|
| 1216 |
-
dtype=self.softmax_dtype,
|
| 1217 |
-
) # [batch, padding_length (inp or cont), vocab]
|
| 1218 |
-
|
| 1219 |
-
for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
|
| 1220 |
-
chunk, multi_logits, inplens, cont_toks_list
|
| 1221 |
-
):
|
| 1222 |
-
# Slice to original seq length
|
| 1223 |
-
contlen = len(cont_toks)
|
| 1224 |
-
# take only logits in the continuation
|
| 1225 |
-
# (discard context toks if decoder-only ; discard right-padding)
|
| 1226 |
-
# also discards + checks for "virtual tokens" in the causal LM's input window
|
| 1227 |
-
# from prompt/prefix tuning tokens, if applicable
|
| 1228 |
-
ctx_len = (
|
| 1229 |
-
inplen + (logits.shape[0] - padding_len_inp)
|
| 1230 |
-
if self.backend == "causal"
|
| 1231 |
-
else None
|
| 1232 |
-
)
|
| 1233 |
-
logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
|
| 1234 |
-
logits = logits.unsqueeze(0) # [1, seq, vocab]
|
| 1235 |
-
|
| 1236 |
-
# Check if per-token argmax is exactly equal to continuation
|
| 1237 |
-
greedy_tokens = logits.argmax(dim=-1)
|
| 1238 |
-
|
| 1239 |
-
# check for one-token continuation cache hits.
|
| 1240 |
-
# noop in case group_by != "contexts" or no cache hit and returns the
|
| 1241 |
-
# original args. Otherwise, expands the logits batch dimension and yields each
|
| 1242 |
-
# batch along with matching continuation tokens and prompt strings.
|
| 1243 |
-
# logits -> [1, seq, vocab]
|
| 1244 |
-
for request_str, cont_toks, logits in re_ord.get_cache(
|
| 1245 |
-
req_str=request_str,
|
| 1246 |
-
cxt_toks=ctx_tokens,
|
| 1247 |
-
cont_toks=cont_toks,
|
| 1248 |
-
logits=logits,
|
| 1249 |
-
):
|
| 1250 |
-
cont_toks = torch.tensor(
|
| 1251 |
-
cont_toks, dtype=torch.long, device=self.device
|
| 1252 |
-
).unsqueeze(0) # [1, seq]
|
| 1253 |
-
# Use trailing slice [-cont_toks.shape[1]:] to handle variable length cont_len (but same ctx+cont[:-1]).
|
| 1254 |
-
# i.e. continuations can be sliced at diff points. Collator ensures we have sufficient greedy_tokens
|
| 1255 |
-
# by choosing key with longest cont if group_by="contexts".
|
| 1256 |
-
max_equal = (
|
| 1257 |
-
greedy_tokens[:, -cont_toks.shape[1] :] == cont_toks
|
| 1258 |
-
).all()
|
| 1259 |
-
|
| 1260 |
-
# Obtain log-probs at the corresponding continuation token indices
|
| 1261 |
-
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
|
| 1262 |
-
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
|
| 1263 |
-
-1
|
| 1264 |
-
) # [1, seq]
|
| 1265 |
-
|
| 1266 |
-
# Answer: (log prob, is-exact-match)
|
| 1267 |
-
answer = (float(logits.sum()), bool(max_equal))
|
| 1268 |
-
|
| 1269 |
-
res.append(answer)
|
| 1270 |
-
|
| 1271 |
-
if request_str is not None:
|
| 1272 |
-
# special case: loglikelihood_rolling produces a number of loglikelihood requests
|
| 1273 |
-
# all with cache key None. instead do add_partial on the per-example level
|
| 1274 |
-
# in the loglikelihood_rolling() function for those.
|
| 1275 |
-
self.cache_hook.add_partial(
|
| 1276 |
-
"loglikelihood", request_str, answer
|
| 1277 |
-
)
|
| 1278 |
-
pbar.update(1)
|
| 1279 |
-
|
| 1280 |
-
pbar.close()
|
| 1281 |
-
|
| 1282 |
-
return re_ord.get_original(res)
|
| 1283 |
-
|
| 1284 |
-
def generate_until(
|
| 1285 |
-
self, requests: List[Instance], disable_tqdm: bool = False
|
| 1286 |
-
) -> List[str]:
|
| 1287 |
-
res = []
|
| 1288 |
-
|
| 1289 |
-
def _collate(req: Tuple[str, dict]):
|
| 1290 |
-
"""Defines the key for the sorted method"""
|
| 1291 |
-
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 1292 |
-
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 1293 |
-
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 1294 |
-
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 1295 |
-
# automatic adaptive batches much much easier to implement
|
| 1296 |
-
# - any OOMs will happen right away rather than near the end
|
| 1297 |
-
toks = self.tok_encode(req[0])
|
| 1298 |
-
return -len(toks), req[0]
|
| 1299 |
-
|
| 1300 |
-
pbar = tqdm(
|
| 1301 |
-
total=len(requests),
|
| 1302 |
-
disable=(disable_tqdm or (self.rank != 0)),
|
| 1303 |
-
desc="Running generate_until requests",
|
| 1304 |
-
)
|
| 1305 |
-
adaptive_batch_size = None
|
| 1306 |
-
if self.batch_size == "auto":
|
| 1307 |
-
# using rolling window with maximum context
|
| 1308 |
-
print("Passed argument batch_size = auto. Detecting largest batch size")
|
| 1309 |
-
batch_size = self._detect_batch_size()
|
| 1310 |
-
print(f"Determined Largest batch size: {batch_size}")
|
| 1311 |
-
adaptive_batch_size = batch_size
|
| 1312 |
-
# for each different set of kwargs, we execute all requests, by batch.
|
| 1313 |
-
batch_size = (
|
| 1314 |
-
self.batch_size
|
| 1315 |
-
if self.batch_size != "auto"
|
| 1316 |
-
else adaptive_batch_size
|
| 1317 |
-
if adaptive_batch_size is not None
|
| 1318 |
-
else 0
|
| 1319 |
-
)
|
| 1320 |
-
batch_fn = (
|
| 1321 |
-
self._batch_scheduler
|
| 1322 |
-
if self.batch_size == "auto" and not adaptive_batch_size
|
| 1323 |
-
else None
|
| 1324 |
-
)
|
| 1325 |
-
|
| 1326 |
-
# we group requests by their generation_kwargs,
|
| 1327 |
-
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
|
| 1328 |
-
# in the same batch.
|
| 1329 |
-
# group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
|
| 1330 |
-
re_ords = Collator(
|
| 1331 |
-
[reg.args for reg in requests],
|
| 1332 |
-
sort_fn=_collate,
|
| 1333 |
-
group_by="gen_kwargs",
|
| 1334 |
-
group_fn=lambda x: x[1],
|
| 1335 |
-
)
|
| 1336 |
-
chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
|
| 1337 |
-
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
|
| 1338 |
-
for chunk in chunks:
|
| 1339 |
-
contexts, all_gen_kwargs = zip(*chunk)
|
| 1340 |
-
# we assume all gen kwargs in the batch are the same
|
| 1341 |
-
# this is safe to assume because the `grouper` object ensures it.
|
| 1342 |
-
gen_kwargs = all_gen_kwargs[0]
|
| 1343 |
-
# unpack our keyword arguments.
|
| 1344 |
-
if isinstance(gen_kwargs, dict):
|
| 1345 |
-
kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
|
| 1346 |
-
# add EOS token to stop sequences
|
| 1347 |
-
until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
|
| 1348 |
-
else:
|
| 1349 |
-
raise ValueError(
|
| 1350 |
-
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
|
| 1351 |
-
)
|
| 1352 |
-
if "max_gen_toks" in kwargs.keys():
|
| 1353 |
-
max_gen_toks = kwargs.pop("max_gen_toks")
|
| 1354 |
-
else:
|
| 1355 |
-
max_gen_toks = self.max_gen_toks
|
| 1356 |
-
|
| 1357 |
-
# set the max length in tokens of inputs ("context_enc")
|
| 1358 |
-
if self.backend == "causal":
|
| 1359 |
-
# max len for inputs = max length, minus room to generate the max new tokens
|
| 1360 |
-
max_ctx_len = self.max_length - max_gen_toks
|
| 1361 |
-
assert max_ctx_len > 0, (
|
| 1362 |
-
f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
|
| 1363 |
-
)
|
| 1364 |
-
elif self.backend == "seq2seq":
|
| 1365 |
-
# max len for inputs = encoder's whole max_length
|
| 1366 |
-
max_ctx_len = self.max_length
|
| 1367 |
-
|
| 1368 |
-
# encode, pad, and truncate contexts for this batch
|
| 1369 |
-
context_enc, attn_masks = self.tok_batch_encode(
|
| 1370 |
-
contexts,
|
| 1371 |
-
left_truncate_len=max_ctx_len,
|
| 1372 |
-
truncation=self.truncation,
|
| 1373 |
-
)
|
| 1374 |
-
context_enc = context_enc.to(self.device)
|
| 1375 |
-
attn_masks = attn_masks.to(self.device)
|
| 1376 |
-
|
| 1377 |
-
if "max_length" not in kwargs:
|
| 1378 |
-
kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
|
| 1379 |
-
|
| 1380 |
-
# perform batched generation
|
| 1381 |
-
cont = self._model_generate(
|
| 1382 |
-
context=context_enc,
|
| 1383 |
-
attention_mask=attn_masks,
|
| 1384 |
-
stop=until,
|
| 1385 |
-
**kwargs,
|
| 1386 |
-
)
|
| 1387 |
-
|
| 1388 |
-
cont_toks_list = cont.tolist()
|
| 1389 |
-
for cont_toks, context in zip(cont_toks_list, contexts):
|
| 1390 |
-
# discard context + left-padding toks if using causal decoder-only LM
|
| 1391 |
-
if self.backend == "causal":
|
| 1392 |
-
cont_toks = cont_toks[context_enc.shape[1] :]
|
| 1393 |
-
|
| 1394 |
-
s = self.tok_decode(cont_toks)
|
| 1395 |
-
|
| 1396 |
-
# use secondary stop seqs to cut off should-have-been-stopped content post-hoc
|
| 1397 |
-
for term in until:
|
| 1398 |
-
if len(term) > 0:
|
| 1399 |
-
# ignore '' separator,
|
| 1400 |
-
# for seq2seq case where self.tok_decode(self.eot_token_id) = ''
|
| 1401 |
-
s = s.split(term)[0]
|
| 1402 |
-
|
| 1403 |
-
res.append(s)
|
| 1404 |
-
|
| 1405 |
-
self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
|
| 1406 |
-
pbar.update(1)
|
| 1407 |
-
# reorder this group of results back to original unsorted form
|
| 1408 |
-
res = re_ords.get_original(res)
|
| 1409 |
-
|
| 1410 |
-
pbar.close()
|
| 1411 |
-
|
| 1412 |
-
return res
|
| 1413 |
-
|
| 1414 |
-
def apply_chat_template(
|
| 1415 |
-
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
|
| 1416 |
-
) -> str:
|
| 1417 |
-
"""
|
| 1418 |
-
Method to apply a chat template to a list of chat history between user and model.
|
| 1419 |
-
"""
|
| 1420 |
-
try:
|
| 1421 |
-
chat_templated = self.tokenizer.apply_chat_template(
|
| 1422 |
-
chat_history,
|
| 1423 |
-
tokenize=False,
|
| 1424 |
-
add_generation_prompt=add_generation_prompt,
|
| 1425 |
-
continue_final_message=not add_generation_prompt,
|
| 1426 |
-
)
|
| 1427 |
-
except jinja2.exceptions.TemplateError:
|
| 1428 |
-
eval_logger.warning(
|
| 1429 |
-
"Failed to apply chat template. removing the system role in chat history."
|
| 1430 |
-
)
|
| 1431 |
-
chat_history = [msg for msg in chat_history if msg["role"] != "system"]
|
| 1432 |
-
chat_templated = self.tokenizer.apply_chat_template(
|
| 1433 |
-
chat_history,
|
| 1434 |
-
tokenize=False,
|
| 1435 |
-
add_generation_prompt=add_generation_prompt,
|
| 1436 |
-
continue_final_message=not add_generation_prompt,
|
| 1437 |
-
)
|
| 1438 |
-
|
| 1439 |
-
return chat_templated
|
| 1440 |
-
|
| 1441 |
-
def get_model_info(self) -> dict:
|
| 1442 |
-
"""
|
| 1443 |
-
Method to get Hugging Face model information for experiment reproducibility.
|
| 1444 |
-
"""
|
| 1445 |
-
|
| 1446 |
-
def get_model_num_params(model) -> int:
|
| 1447 |
-
if hasattr(model, "num_parameters"):
|
| 1448 |
-
return model.num_parameters()
|
| 1449 |
-
if hasattr(model, "parameters"):
|
| 1450 |
-
return sum(p.numel() for p in model.parameters())
|
| 1451 |
-
else:
|
| 1452 |
-
return -1
|
| 1453 |
-
|
| 1454 |
-
def get_model_dtype(model) -> str:
|
| 1455 |
-
if hasattr(model, "dtype"):
|
| 1456 |
-
return model.dtype
|
| 1457 |
-
else:
|
| 1458 |
-
return ""
|
| 1459 |
-
|
| 1460 |
-
def get_model_sha(pretrained: str, revision: str) -> str:
|
| 1461 |
-
try:
|
| 1462 |
-
model_info = HfApi().model_info(repo_id=pretrained, revision=revision)
|
| 1463 |
-
return model_info.sha
|
| 1464 |
-
except Exception as e:
|
| 1465 |
-
eval_logger.debug(
|
| 1466 |
-
f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}"
|
| 1467 |
-
)
|
| 1468 |
-
return ""
|
| 1469 |
-
|
| 1470 |
-
model_info = {
|
| 1471 |
-
"model_num_parameters": get_model_num_params(self._model),
|
| 1472 |
-
"model_dtype": get_model_dtype(self._model),
|
| 1473 |
-
"model_revision": self.revision,
|
| 1474 |
-
"model_sha": get_model_sha(self.pretrained, self.revision),
|
| 1475 |
-
}
|
| 1476 |
-
if self.peft:
|
| 1477 |
-
model_info["peft_sha"] = get_model_sha(self.peft, self.revision)
|
| 1478 |
-
if self.delta:
|
| 1479 |
-
model_info["delta_sha"] = get_model_sha(self.delta, self.revision)
|
| 1480 |
-
return model_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lm-evaluation-harness/lm_eval/models/ibm_watsonx_ai.py
DELETED
|
@@ -1,445 +0,0 @@
|
|
| 1 |
-
import copy
|
| 2 |
-
import json
|
| 3 |
-
import logging
|
| 4 |
-
import os
|
| 5 |
-
import warnings
|
| 6 |
-
from functools import lru_cache
|
| 7 |
-
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
|
| 8 |
-
|
| 9 |
-
from tqdm import tqdm
|
| 10 |
-
|
| 11 |
-
from lm_eval.api.instance import Instance
|
| 12 |
-
from lm_eval.api.model import LM
|
| 13 |
-
from lm_eval.api.registry import register_model
|
| 14 |
-
from lm_eval.models.api_models import JsonChatStr
|
| 15 |
-
from lm_eval.utils import simple_parse_args_string
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
eval_logger = logging.getLogger(__name__)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class LogLikelihoodResult(NamedTuple):
|
| 22 |
-
log_likelihood: float
|
| 23 |
-
is_greedy: bool
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _verify_credentials(creds: dict) -> None:
|
| 27 |
-
"""
|
| 28 |
-
Validate credentials for APIClient authentication.
|
| 29 |
-
|
| 30 |
-
Required conditions:
|
| 31 |
-
- Either ("username" and "password") or "apikey" must be present.
|
| 32 |
-
- "url" is mandatory.
|
| 33 |
-
- Either "project_id" or "space_id" must be present.
|
| 34 |
-
"""
|
| 35 |
-
env_var_map = {
|
| 36 |
-
"apikey": "WATSONX_API_KEY",
|
| 37 |
-
"token": "WATSONX_TOKEN",
|
| 38 |
-
"url": "WATSONX_URL",
|
| 39 |
-
"project_id": "WATSONX_PROJECT_ID",
|
| 40 |
-
"space_id": "WATSONX_SPACE_ID",
|
| 41 |
-
"username": "WATSONX_USERNAME",
|
| 42 |
-
"password": "WATSONX_PASSWORD",
|
| 43 |
-
}
|
| 44 |
-
|
| 45 |
-
# Check authentication: Either ("username" and "password") or "apikey" must be provided
|
| 46 |
-
has_auth = all(creds.get(key) for key in ["username", "password"]) or creds.get(
|
| 47 |
-
"apikey"
|
| 48 |
-
)
|
| 49 |
-
# Check required fields: "url" must be present
|
| 50 |
-
has_url = "url" in creds and creds["url"]
|
| 51 |
-
# Check project/space ID requirement: Either "project_id" or "space_id" must be present
|
| 52 |
-
has_project_or_space_id = any(creds.get(key) for key in ["project_id", "space_id"])
|
| 53 |
-
|
| 54 |
-
if not (has_auth and has_url and has_project_or_space_id):
|
| 55 |
-
missing_keys = []
|
| 56 |
-
if not has_auth:
|
| 57 |
-
missing_keys.append(
|
| 58 |
-
f"either ('username' and 'password') or 'apikey' ({env_var_map['apikey']})"
|
| 59 |
-
)
|
| 60 |
-
if not has_url:
|
| 61 |
-
missing_keys.append(f"url ({env_var_map['url']})")
|
| 62 |
-
if not has_project_or_space_id:
|
| 63 |
-
missing_keys.append(
|
| 64 |
-
f"either 'project_id' ({env_var_map['project_id']}) or 'space_id' ({env_var_map['space_id']})"
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
error_msg = f"Missing required credentials: {', '.join(missing_keys)}. "
|
| 68 |
-
error_msg += "Please set the environment variables indicated in parentheses."
|
| 69 |
-
raise ValueError(error_msg)
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
@lru_cache(maxsize=None)
|
| 73 |
-
def get_watsonx_credentials() -> Dict[str, str]:
|
| 74 |
-
"""
|
| 75 |
-
Retrieves Watsonx API credentials from environmental variables.
|
| 76 |
-
Returns:
|
| 77 |
-
Dict[str, str]: A dictionary containing the credentials necessary for authentication, including
|
| 78 |
-
keys such as `apikey` or `token`, `url`, and `project_id`.
|
| 79 |
-
Raises:
|
| 80 |
-
AssertionError: If the credentials format is invalid or any of the necessary credentials are missing.
|
| 81 |
-
"""
|
| 82 |
-
try:
|
| 83 |
-
from dotenv import load_dotenv
|
| 84 |
-
except ImportError:
|
| 85 |
-
raise ImportError(
|
| 86 |
-
"Could not import dotenv: Please install lm_eval[ibm_watsonx_ai] package."
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
# This function attempts to load a file named .env starting from the CWD and working backwards
|
| 90 |
-
# towards root. KV pairs are parsed and stored as env vars iff not already set
|
| 91 |
-
load_dotenv()
|
| 92 |
-
|
| 93 |
-
credentials = {
|
| 94 |
-
"username": os.getenv("WATSONX_USERNAME", None),
|
| 95 |
-
"password": os.getenv("WATSONX_PASSWORD", None),
|
| 96 |
-
"apikey": os.getenv("WATSONX_API_KEY", None),
|
| 97 |
-
"token": os.getenv("WATSONX_TOKEN", None),
|
| 98 |
-
"url": os.getenv("WATSONX_URL", None),
|
| 99 |
-
"project_id": os.getenv("WATSONX_PROJECT_ID", None),
|
| 100 |
-
"space_id": os.getenv("WATSONX_SPACE_ID", None),
|
| 101 |
-
}
|
| 102 |
-
if "cloud.ibm.com" not in credentials["url"]:
|
| 103 |
-
credentials["instance_id"] = "openshift"
|
| 104 |
-
|
| 105 |
-
if all(credentials.get(key) for key in ["username", "password", "apikey"]):
|
| 106 |
-
warnings.warn(
|
| 107 |
-
"You're passing `username`, `password`, and `apikey` at the same time, "
|
| 108 |
-
"which might cause issues. More info on authentication in different scenarios "
|
| 109 |
-
"can be found in the docs: https://ibm.github.io/watsonx-ai-python-sdk/setup_cpd.html"
|
| 110 |
-
)
|
| 111 |
-
_verify_credentials(credentials)
|
| 112 |
-
return credentials
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
@register_model("watsonx_llm")
|
| 116 |
-
class WatsonxLLM(LM):
|
| 117 |
-
"""
|
| 118 |
-
Implementation of LM model interface for evaluating Watsonx model with the lm_eval framework.
|
| 119 |
-
See https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/model_guide.md for reference.
|
| 120 |
-
"""
|
| 121 |
-
|
| 122 |
-
@classmethod
|
| 123 |
-
def create_from_arg_string(
|
| 124 |
-
cls: Type["WatsonxLLM"],
|
| 125 |
-
arg_string: str,
|
| 126 |
-
additional_config: Optional[Dict] = None,
|
| 127 |
-
) -> "WatsonxLLM":
|
| 128 |
-
"""
|
| 129 |
-
Allow the user to specify model parameters (TextGenerationParameters) in CLI arguments.
|
| 130 |
-
"""
|
| 131 |
-
try:
|
| 132 |
-
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
| 133 |
-
except ImportError:
|
| 134 |
-
raise ImportError(
|
| 135 |
-
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
args = simple_parse_args_string(arg_string)
|
| 139 |
-
args.update(additional_config)
|
| 140 |
-
|
| 141 |
-
model_id = args.pop("model_id", None)
|
| 142 |
-
deployment_id = args.pop("deployment_id", None)
|
| 143 |
-
if model_id is None and deployment_id is None:
|
| 144 |
-
raise ValueError(
|
| 145 |
-
"'model_id' or 'deployment_id' is required, please pass it in 'model_args'"
|
| 146 |
-
)
|
| 147 |
-
|
| 148 |
-
if not args.get("do_sample", None):
|
| 149 |
-
args["temperature"] = None
|
| 150 |
-
args["top_p"] = None
|
| 151 |
-
args["top_k"] = None
|
| 152 |
-
args["seed"] = None
|
| 153 |
-
|
| 154 |
-
generate_params = {
|
| 155 |
-
GenParams.DECODING_METHOD: (
|
| 156 |
-
"greedy" if not args.get("do_sample", None) else "sample"
|
| 157 |
-
),
|
| 158 |
-
GenParams.LENGTH_PENALTY: args.get("length_penalty", None),
|
| 159 |
-
GenParams.TEMPERATURE: args.get("temperature", None),
|
| 160 |
-
GenParams.TOP_P: args.get("top_p", None),
|
| 161 |
-
GenParams.TOP_K: args.get("top_k", None),
|
| 162 |
-
GenParams.RANDOM_SEED: args.get("seed", None),
|
| 163 |
-
GenParams.REPETITION_PENALTY: args.get("repetition_penalty", None),
|
| 164 |
-
GenParams.MIN_NEW_TOKENS: args.get("min_new_tokens", None),
|
| 165 |
-
GenParams.MAX_NEW_TOKENS: args.get("max_new_tokens", 256),
|
| 166 |
-
GenParams.STOP_SEQUENCES: args.get("stop_sequences", None),
|
| 167 |
-
GenParams.TIME_LIMIT: args.get("time_limit", None),
|
| 168 |
-
GenParams.TRUNCATE_INPUT_TOKENS: args.get("truncate_input_tokens", None),
|
| 169 |
-
GenParams.RETURN_OPTIONS: {
|
| 170 |
-
"generated_tokens": True,
|
| 171 |
-
"input_tokens": True,
|
| 172 |
-
"token_logprobs": True,
|
| 173 |
-
"token_ranks": True,
|
| 174 |
-
},
|
| 175 |
-
}
|
| 176 |
-
|
| 177 |
-
generate_params = {k: v for k, v in generate_params.items() if v is not None}
|
| 178 |
-
|
| 179 |
-
return cls(
|
| 180 |
-
watsonx_credentials=get_watsonx_credentials(),
|
| 181 |
-
model_id=model_id,
|
| 182 |
-
deployment_id=deployment_id,
|
| 183 |
-
generate_params=generate_params,
|
| 184 |
-
)
|
| 185 |
-
|
| 186 |
-
def __init__(
|
| 187 |
-
self,
|
| 188 |
-
watsonx_credentials: Dict,
|
| 189 |
-
model_id,
|
| 190 |
-
deployment_id,
|
| 191 |
-
generate_params: Optional[Dict[Any, Any]] = None,
|
| 192 |
-
) -> None:
|
| 193 |
-
try:
|
| 194 |
-
from ibm_watsonx_ai import APIClient
|
| 195 |
-
from ibm_watsonx_ai.foundation_models import ModelInference
|
| 196 |
-
except ImportError:
|
| 197 |
-
raise ImportError(
|
| 198 |
-
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
|
| 199 |
-
)
|
| 200 |
-
super().__init__()
|
| 201 |
-
client = APIClient(watsonx_credentials)
|
| 202 |
-
project_id = watsonx_credentials.get("project_id", None)
|
| 203 |
-
client.set.default_project(project_id)
|
| 204 |
-
self.generate_params = generate_params
|
| 205 |
-
self.model = ModelInference(
|
| 206 |
-
model_id=model_id,
|
| 207 |
-
deployment_id=deployment_id,
|
| 208 |
-
api_client=client,
|
| 209 |
-
project_id=project_id,
|
| 210 |
-
)
|
| 211 |
-
self._model_id = model_id
|
| 212 |
-
|
| 213 |
-
@staticmethod
|
| 214 |
-
def _has_stop_token(response_tokens: List[str], context_tokens: List[str]) -> bool:
|
| 215 |
-
"""
|
| 216 |
-
Determines whether a stop token has been generated in the `response_tokens` compared to the `context_tokens`.
|
| 217 |
-
If the tokens do not match as expected, the function raises a RuntimeError, indicating a possible
|
| 218 |
-
misalignment between the tokens generated by the tokenizer and the model.
|
| 219 |
-
Args:
|
| 220 |
-
response_tokens (List[str]): The List of tokens generated as a response by the model.
|
| 221 |
-
context_tokens (List[str]): The List of tokens representing the input context.
|
| 222 |
-
Returns:
|
| 223 |
-
bool: True if the `response_tokens` likely contain a stop token that terminates the sequence,
|
| 224 |
-
otherwise raises an exception.
|
| 225 |
-
Raises:
|
| 226 |
-
RuntimeError: If there is an unexpected mismatch between the `response_tokens` and the `context_tokens`.
|
| 227 |
-
"""
|
| 228 |
-
context_length = len(context_tokens)
|
| 229 |
-
if response_tokens[: context_length - 1] == context_tokens[:-1]:
|
| 230 |
-
return (
|
| 231 |
-
response_tokens[-1] != context_tokens[-1]
|
| 232 |
-
) # only last token differs, probably stop sequence (</s>)
|
| 233 |
-
raise RuntimeError(
|
| 234 |
-
f"There is an unexpected difference between tokenizer and model tokens:\n"
|
| 235 |
-
f"context_tokens={context_tokens}\n"
|
| 236 |
-
f"response_tokens={response_tokens[:context_length]}"
|
| 237 |
-
)
|
| 238 |
-
|
| 239 |
-
def _check_model_logprobs_support(self):
|
| 240 |
-
"""
|
| 241 |
-
Verifies if the model supports returning log probabilities for input tokens.
|
| 242 |
-
This function sends a prompt to the model and checks whether the model's response
|
| 243 |
-
includes log probabilities for the input tokens. If log probabilities are not present,
|
| 244 |
-
it raises a `RuntimeError`, indicating that the model is not supported.
|
| 245 |
-
Raises:
|
| 246 |
-
RuntimeError: If the model does not return log probabilities for input tokens.
|
| 247 |
-
"""
|
| 248 |
-
tokens = self.model.generate_text(
|
| 249 |
-
prompt=["The best ice cream flavor is:"],
|
| 250 |
-
params=self.generate_params,
|
| 251 |
-
raw_response=True,
|
| 252 |
-
)[0]["results"][0]
|
| 253 |
-
if all(token.get("logprob", None) is None for token in tokens["input_tokens"]):
|
| 254 |
-
raise RuntimeError(
|
| 255 |
-
f"Model {self._model_id} is not supported: does not return logprobs for input tokens"
|
| 256 |
-
)
|
| 257 |
-
|
| 258 |
-
def _get_log_likelihood(
|
| 259 |
-
self,
|
| 260 |
-
input_tokens: List[Dict[str, float]],
|
| 261 |
-
context_tokens: List[Dict[str, float]],
|
| 262 |
-
) -> LogLikelihoodResult:
|
| 263 |
-
"""
|
| 264 |
-
Calculates the log likelihood of the generated tokens compared to the context tokens.
|
| 265 |
-
Args:
|
| 266 |
-
input_tokens (List[Dict[str, float]]): A List of token dictionaries, each containing
|
| 267 |
-
token information like `text` and `logprob`.
|
| 268 |
-
context_tokens (List[Dict[str, float]]): A List of token dictionaries representing
|
| 269 |
-
the input context.
|
| 270 |
-
Returns:
|
| 271 |
-
LogLikelihoodResult: An object containing the calculated log likelihood and a boolean
|
| 272 |
-
flag indicating if the tokens were generated greedily.
|
| 273 |
-
"""
|
| 274 |
-
|
| 275 |
-
response_tokens = [token["text"] for token in input_tokens]
|
| 276 |
-
context_length = len(context_tokens)
|
| 277 |
-
|
| 278 |
-
if self._has_stop_token(response_tokens, context_tokens):
|
| 279 |
-
context_length -= 1
|
| 280 |
-
|
| 281 |
-
return LogLikelihoodResult(
|
| 282 |
-
log_likelihood=sum(
|
| 283 |
-
token.get("logprob", 0) for token in input_tokens[context_length:]
|
| 284 |
-
),
|
| 285 |
-
is_greedy=all(
|
| 286 |
-
token["rank"] == 1 for token in input_tokens[context_length:]
|
| 287 |
-
),
|
| 288 |
-
)
|
| 289 |
-
|
| 290 |
-
def generate_until(self, requests: List[Instance]) -> List[str]:
|
| 291 |
-
"""
|
| 292 |
-
Generates text responses for a List of requests, with progress tracking and caching.
|
| 293 |
-
Args:
|
| 294 |
-
requests (List[Instance]): A List of instances, each containing a text input to be processed.
|
| 295 |
-
Returns:
|
| 296 |
-
List[str]: A List of generated responses.
|
| 297 |
-
"""
|
| 298 |
-
requests = [request.args for request in requests]
|
| 299 |
-
results = []
|
| 300 |
-
|
| 301 |
-
for request in tqdm(
|
| 302 |
-
requests,
|
| 303 |
-
desc="Running generate_until function ...",
|
| 304 |
-
):
|
| 305 |
-
context, continuation = request
|
| 306 |
-
try:
|
| 307 |
-
if isinstance(context, JsonChatStr):
|
| 308 |
-
context = json.loads(context.prompt)
|
| 309 |
-
response = self.model.chat(context, self.generate_params)
|
| 310 |
-
response = response["choices"][0]["message"]["content"]
|
| 311 |
-
else:
|
| 312 |
-
response = self.model.generate_text(context, self.generate_params)
|
| 313 |
-
except Exception as exp:
|
| 314 |
-
eval_logger.error("Error while generating text.")
|
| 315 |
-
raise exp
|
| 316 |
-
|
| 317 |
-
results.append(response)
|
| 318 |
-
self.cache_hook.add_partial(
|
| 319 |
-
"generate_until", (context, continuation), response
|
| 320 |
-
)
|
| 321 |
-
|
| 322 |
-
return results
|
| 323 |
-
|
| 324 |
-
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
|
| 325 |
-
"""
|
| 326 |
-
Args:
|
| 327 |
-
requests: Each request contains Instance.args : Tuple[str, str] containing:
|
| 328 |
-
1. an input string to the LM and
|
| 329 |
-
2. a target string on which the loglikelihood of the LM producing this target,
|
| 330 |
-
conditioned on the input, will be returned.
|
| 331 |
-
Returns:
|
| 332 |
-
Tuple (loglikelihood, is_greedy) for each request according to the input order:
|
| 333 |
-
loglikelihood: probability of generating the target string conditioned on the input
|
| 334 |
-
is_greedy: True if and only if the target string would be generated by greedy sampling from the LM
|
| 335 |
-
"""
|
| 336 |
-
try:
|
| 337 |
-
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
| 338 |
-
except ImportError:
|
| 339 |
-
raise ImportError(
|
| 340 |
-
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
|
| 341 |
-
)
|
| 342 |
-
self._check_model_logprobs_support()
|
| 343 |
-
generate_params = copy.copy(self.generate_params)
|
| 344 |
-
generate_params[GenParams.MAX_NEW_TOKENS] = 1
|
| 345 |
-
|
| 346 |
-
requests = [request.args for request in requests]
|
| 347 |
-
results: List[LogLikelihoodResult] = []
|
| 348 |
-
|
| 349 |
-
# Note: We're not using batching due to (current) indeterminism of loglikelihood values when sending batch of requests
|
| 350 |
-
for request in tqdm(
|
| 351 |
-
requests,
|
| 352 |
-
desc="Running loglikelihood function ...",
|
| 353 |
-
):
|
| 354 |
-
context, continuation = request
|
| 355 |
-
try:
|
| 356 |
-
tokenized_context = self.model.tokenize(
|
| 357 |
-
prompt=context, return_tokens=True
|
| 358 |
-
)["result"]["tokens"]
|
| 359 |
-
except Exception as exp:
|
| 360 |
-
eval_logger.error("Error while model tokenize.")
|
| 361 |
-
raise exp
|
| 362 |
-
|
| 363 |
-
input_prompt = context + continuation
|
| 364 |
-
|
| 365 |
-
try:
|
| 366 |
-
response = self.model.generate_text(
|
| 367 |
-
prompt=input_prompt, params=generate_params, raw_response=True
|
| 368 |
-
)
|
| 369 |
-
except Exception as exp:
|
| 370 |
-
eval_logger.error("Error while model generate text.")
|
| 371 |
-
raise exp
|
| 372 |
-
|
| 373 |
-
log_likelihood_response = self._get_log_likelihood(
|
| 374 |
-
response["results"][0]["input_tokens"], tokenized_context
|
| 375 |
-
)
|
| 376 |
-
results.append(log_likelihood_response)
|
| 377 |
-
self.cache_hook.add_partial(
|
| 378 |
-
"loglikelihood",
|
| 379 |
-
(context, continuation),
|
| 380 |
-
(
|
| 381 |
-
log_likelihood_response.log_likelihood,
|
| 382 |
-
log_likelihood_response.is_greedy,
|
| 383 |
-
),
|
| 384 |
-
)
|
| 385 |
-
|
| 386 |
-
return cast(List[Tuple[float, bool]], results)
|
| 387 |
-
|
| 388 |
-
def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
|
| 389 |
-
"""
|
| 390 |
-
Used to evaluate perplexity on a data distribution.
|
| 391 |
-
Args:
|
| 392 |
-
requests: Each request contains Instance.args : Tuple[str] containing an input string to the model whose
|
| 393 |
-
entire loglikelihood, conditioned on purely the EOT token, will be calculated.
|
| 394 |
-
Returns:
|
| 395 |
-
Tuple (loglikelihood,) for each request according to the input order:
|
| 396 |
-
loglikelihood: solely the probability of producing each piece of text given no starting input.
|
| 397 |
-
"""
|
| 398 |
-
try:
|
| 399 |
-
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
| 400 |
-
except ImportError:
|
| 401 |
-
raise ImportError(
|
| 402 |
-
"Could not import ibm_watsonx_ai: Please install lm_eval[ibm_watsonx_ai] package."
|
| 403 |
-
)
|
| 404 |
-
self._check_model_logprobs_support()
|
| 405 |
-
generate_params = copy.deepcopy(self.generate_params)
|
| 406 |
-
generate_params[GenParams.MAX_NEW_TOKENS] = 1
|
| 407 |
-
|
| 408 |
-
requests = [request.args for request in requests]
|
| 409 |
-
results: List[LogLikelihoodResult] = []
|
| 410 |
-
|
| 411 |
-
# Note: We're not using batching due to (current) indeterminism of loglikelihood values when sending batch of requests
|
| 412 |
-
for request in tqdm(
|
| 413 |
-
requests,
|
| 414 |
-
desc="Running loglikelihood_rolling function ...",
|
| 415 |
-
):
|
| 416 |
-
context, continuation = request
|
| 417 |
-
try:
|
| 418 |
-
response = self.model.generate_text(
|
| 419 |
-
prompt=context, params=generate_params, raw_response=True
|
| 420 |
-
)
|
| 421 |
-
except Exception as exp:
|
| 422 |
-
eval_logger.error("Error while model generate text.")
|
| 423 |
-
raise exp
|
| 424 |
-
|
| 425 |
-
log_likelihood_response = self._get_log_likelihood(
|
| 426 |
-
response["results"][0]["input_tokens"], []
|
| 427 |
-
)
|
| 428 |
-
results.append(log_likelihood_response)
|
| 429 |
-
self.cache_hook.add_partial(
|
| 430 |
-
"loglikelihood_rolling",
|
| 431 |
-
(context, continuation),
|
| 432 |
-
log_likelihood_response.log_likelihood,
|
| 433 |
-
)
|
| 434 |
-
|
| 435 |
-
return cast(List[Tuple[float, bool]], results)
|
| 436 |
-
|
| 437 |
-
@property
|
| 438 |
-
def tokenizer_name(self) -> str:
|
| 439 |
-
return ""
|
| 440 |
-
|
| 441 |
-
def apply_chat_template(
|
| 442 |
-
self, chat_history: List[Dict[str, str]]
|
| 443 |
-
) -> List[Dict[str, str]]:
|
| 444 |
-
# A hack similar from api_model to allow encoding for cache
|
| 445 |
-
return JsonChatStr(json.dumps(chat_history))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|