Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- scripts/yans/eval/lm-evaluation-harness/lm_eval/__init__.py +0 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/base.py +1051 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/evaluator.py +381 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/README.md +2 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/__init__.py +0 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/__pycache__/__init__.cpython-310.pyc +0 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/__pycache__/evaluate.cpython-310.pyc +0 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/__pycache__/jasquad.cpython-310.pyc +0 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/evaluate.py +121 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/jasquad.py +128 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/requirements.txt +1 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/metrics.py +286 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/prompts.py +33 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/suites/__init__.py +56 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/suites/configs/ja8.conf +33 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/coqa.py +178 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/hellaswag.py +77 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/lambada.py +108 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/lambada_multilingual.py +123 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/qa4mre.py +76 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/squad.py +219 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/superglue.py +490 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/translation.py +244 -0
- scripts/yans/eval/lm-evaluation-harness/lm_eval/utils.py +301 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/harness.jsquad-1.2.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/harness.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/result.json +59 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/result.jsquad-1.2.json +22 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/result.mgsm.json +0 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/harness.jsquad-1.2.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/harness.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/result.json +71 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/result.jsquad-1.2.json +22 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/result.mgsm.json +0 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/harness.jsquad-1.2.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/harness.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/result.json +71 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/result.jsquad-1.2.json +22 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/result.mgsm.json +0 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-large/harness.jsquad-1.2.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-large/harness.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-large/result.json +59 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-large/result.jsquad-1.2.json +22 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-medium/harness.jsquad-1.2.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-medium/harness.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-medium/result.json +59 -0
- scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-medium/result.jsquad-1.2.json +22 -0
- scripts/yans/eval/lm-evaluation-harness/models/llama/llama-13b/harness.sh +3 -0
- scripts/yans/eval/lm-evaluation-harness/models/llama/llama-13b/result.json +48 -0
- scripts/yans/eval/lm-evaluation-harness/models/llama/llama-30b/harness.sh +3 -0
scripts/yans/eval/lm-evaluation-harness/lm_eval/__init__.py
ADDED
|
File without changes
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/base.py
ADDED
|
@@ -0,0 +1,1051 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from typing import Iterable
|
| 4 |
+
import numpy as np
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import hashlib
|
| 10 |
+
import datasets
|
| 11 |
+
from sqlitedict import SqliteDict
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
|
| 17 |
+
from lm_eval.metrics import balanced_mean, matthews_corrcoef, macro_f1
|
| 18 |
+
from lm_eval import utils
|
| 19 |
+
from abc import abstractmethod
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class LM(abc.ABC):
|
| 23 |
+
def __init__(self):
|
| 24 |
+
self.cache_hook = CacheHook(None)
|
| 25 |
+
|
| 26 |
+
@abstractmethod
|
| 27 |
+
def loglikelihood(self, requests):
|
| 28 |
+
"""Compute log-likelihood of generating a continuation from a context.
|
| 29 |
+
Downstream tasks should attempt to use loglikelihood instead of other
|
| 30 |
+
LM calls whenever possible.
|
| 31 |
+
|
| 32 |
+
:param requests: list
|
| 33 |
+
A list of pairs (context, continuation)
|
| 34 |
+
context: str
|
| 35 |
+
Context string. Implementations of LM must be able to handle an
|
| 36 |
+
empty context string.
|
| 37 |
+
continuation: str
|
| 38 |
+
The continuation over which log likelihood will be calculated. If
|
| 39 |
+
there is a word boundary, the space should be in the continuation.
|
| 40 |
+
For example, context="hello" continuation=" world" is correct.
|
| 41 |
+
:return: list
|
| 42 |
+
A list of pairs (logprob, isgreedy)
|
| 43 |
+
logprob: float
|
| 44 |
+
The log probability of `continuation`
|
| 45 |
+
isgreedy:
|
| 46 |
+
Whether `continuation` would be generated by greedy sampling from `context`
|
| 47 |
+
"""
|
| 48 |
+
pass
|
| 49 |
+
|
| 50 |
+
@abstractmethod
|
| 51 |
+
def loglikelihood_rolling(self, requests):
|
| 52 |
+
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
|
| 53 |
+
- We will use the full max context length of the model.
|
| 54 |
+
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
|
| 55 |
+
the max context length.
|
| 56 |
+
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
|
| 57 |
+
which may simply concatenate multiple documents together.
|
| 58 |
+
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
|
| 59 |
+
multiple chunks, the last input will still a full-sized context.
|
| 60 |
+
Example:
|
| 61 |
+
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
|
| 62 |
+
Prefix: EOT
|
| 63 |
+
Max context length: 4
|
| 64 |
+
Resulting input/prediction pairs:
|
| 65 |
+
|
| 66 |
+
INPUT: EOT 0 1 2
|
| 67 |
+
PRED: 0 1 2 3
|
| 68 |
+
|
| 69 |
+
INPUT: 3 4 5 6
|
| 70 |
+
PRED: 4 5 6 7
|
| 71 |
+
|
| 72 |
+
INPUT: 5 6 7 8
|
| 73 |
+
PRED: 8 9
|
| 74 |
+
|
| 75 |
+
Observe that:
|
| 76 |
+
1. Each token is predicted exactly once
|
| 77 |
+
2. For the last pair, we provide the full context, but only score the last two tokens
|
| 78 |
+
|
| 79 |
+
:param requests: list
|
| 80 |
+
A list of strings
|
| 81 |
+
string: str
|
| 82 |
+
String for which we are computing per-toke loglikelihood
|
| 83 |
+
:return: list
|
| 84 |
+
A list of pairs (logprob, isgreedy)
|
| 85 |
+
logprob: float
|
| 86 |
+
The log probability of `continuation`
|
| 87 |
+
isgreedy:
|
| 88 |
+
Whether `continuation` would be generated by greedy sampling from `context`
|
| 89 |
+
"""
|
| 90 |
+
pass
|
| 91 |
+
|
| 92 |
+
# TODO: Add an optional max length
|
| 93 |
+
@abstractmethod
|
| 94 |
+
def greedy_until(self, requests):
|
| 95 |
+
"""Generate greedily until a stopping sequence
|
| 96 |
+
|
| 97 |
+
:param requests: list
|
| 98 |
+
A list of pairs (context, until) or (context, until, max_num_tokens)
|
| 99 |
+
context: str
|
| 100 |
+
Context string
|
| 101 |
+
until: [str]
|
| 102 |
+
The string sequences to generate until. These string sequences
|
| 103 |
+
may each span across multiple tokens, or may be part of one token.
|
| 104 |
+
(optional) max_num_tokens: int
|
| 105 |
+
Indicate the max length of the generation
|
| 106 |
+
:return: list
|
| 107 |
+
A list of strings continuation
|
| 108 |
+
continuation: str
|
| 109 |
+
The generated continuation.
|
| 110 |
+
"""
|
| 111 |
+
pass
|
| 112 |
+
|
| 113 |
+
@classmethod
|
| 114 |
+
def create_from_arg_string(cls, arg_string, additional_config=None):
|
| 115 |
+
additional_config = {} if additional_config is None else additional_config
|
| 116 |
+
args = utils.simple_parse_args_string(arg_string)
|
| 117 |
+
args2 = {k: v for k, v in additional_config.items() if v is not None}
|
| 118 |
+
return cls(**args, **args2)
|
| 119 |
+
|
| 120 |
+
def set_cache_hook(self, cache_hook):
|
| 121 |
+
self.cache_hook = cache_hook
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class BaseLM(LM):
|
| 125 |
+
@property
|
| 126 |
+
@abstractmethod
|
| 127 |
+
def eot_token_id(self):
|
| 128 |
+
pass
|
| 129 |
+
|
| 130 |
+
@property
|
| 131 |
+
@abstractmethod
|
| 132 |
+
def max_length(self):
|
| 133 |
+
pass
|
| 134 |
+
|
| 135 |
+
@property
|
| 136 |
+
@abstractmethod
|
| 137 |
+
def max_gen_toks(self):
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
@property
|
| 141 |
+
@abstractmethod
|
| 142 |
+
def batch_size(self):
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
@property
|
| 146 |
+
@abstractmethod
|
| 147 |
+
def device(self):
|
| 148 |
+
pass
|
| 149 |
+
|
| 150 |
+
@abstractmethod
|
| 151 |
+
def tok_encode(self, string: str):
|
| 152 |
+
pass
|
| 153 |
+
|
| 154 |
+
@abstractmethod
|
| 155 |
+
def tok_decode(self, tokens: Iterable[int]):
|
| 156 |
+
pass
|
| 157 |
+
|
| 158 |
+
@abstractmethod
|
| 159 |
+
def _model_generate(self, context, max_length, eos_token_id):
|
| 160 |
+
pass
|
| 161 |
+
|
| 162 |
+
@abstractmethod
|
| 163 |
+
def _model_call(self, inps):
|
| 164 |
+
"""
|
| 165 |
+
inps: a torch tensor of shape [batch, sequence]
|
| 166 |
+
the size of sequence may vary from call to call
|
| 167 |
+
|
| 168 |
+
returns: a torch tensor of shape [batch, sequence, vocab] with the
|
| 169 |
+
logits returned from the model
|
| 170 |
+
"""
|
| 171 |
+
pass
|
| 172 |
+
|
| 173 |
+
# subclass must implement properties vocab_size, eot_token_id, max_gen_toks, batch_size, device, max_length.
|
| 174 |
+
# TODO: enforce this somehow
|
| 175 |
+
|
| 176 |
+
def loglikelihood(self, requests):
|
| 177 |
+
new_reqs = []
|
| 178 |
+
for context, continuation in requests:
|
| 179 |
+
if context == "":
|
| 180 |
+
# end of text as context
|
| 181 |
+
context_enc = [self.eot_token_id]
|
| 182 |
+
else:
|
| 183 |
+
context_enc = self.tok_encode(context)
|
| 184 |
+
if continuation == "__lasttoken__":
|
| 185 |
+
# take last token from context
|
| 186 |
+
continuation_enc = [context_enc[-1]]
|
| 187 |
+
context_enc = context_enc[:-1]
|
| 188 |
+
else:
|
| 189 |
+
continuation_enc = self.tok_encode(continuation)
|
| 190 |
+
|
| 191 |
+
new_reqs.append(((context, continuation), context_enc, continuation_enc))
|
| 192 |
+
|
| 193 |
+
return self._loglikelihood_tokens(new_reqs)
|
| 194 |
+
|
| 195 |
+
def loglikelihood_rolling(self, requests):
|
| 196 |
+
# TODO: Implement caching once we've confirmed the perplexity implementation
|
| 197 |
+
# TODO: automatic batch size detection for vectorization
|
| 198 |
+
|
| 199 |
+
loglikelihoods = []
|
| 200 |
+
for (string,) in tqdm(requests):
|
| 201 |
+
rolling_token_windows = list(
|
| 202 |
+
map(
|
| 203 |
+
utils.make_disjoint_window,
|
| 204 |
+
utils.get_rolling_token_windows(
|
| 205 |
+
token_list=self.tok_encode(string),
|
| 206 |
+
prefix_token=self.eot_token_id,
|
| 207 |
+
max_seq_len=self.max_length,
|
| 208 |
+
context_len=1,
|
| 209 |
+
),
|
| 210 |
+
)
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
|
| 214 |
+
|
| 215 |
+
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
|
| 216 |
+
# that
|
| 217 |
+
string_nll = self._loglikelihood_tokens(
|
| 218 |
+
rolling_token_windows, disable_tqdm=True
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# discard is_greedy
|
| 222 |
+
string_nll = [x[0] for x in string_nll]
|
| 223 |
+
|
| 224 |
+
string_nll = sum(string_nll)
|
| 225 |
+
loglikelihoods.append(string_nll)
|
| 226 |
+
|
| 227 |
+
return loglikelihoods
|
| 228 |
+
|
| 229 |
+
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
|
| 230 |
+
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
|
| 231 |
+
res = []
|
| 232 |
+
|
| 233 |
+
def _collate(x):
|
| 234 |
+
# the negative sign on len(toks) sorts descending - this has a few advantages:
|
| 235 |
+
# - time estimates will always be over not underestimates, which is more useful for planning
|
| 236 |
+
# - to know the size of a batch when going through the list, you know the first one is always the batch
|
| 237 |
+
# padded context length. this is useful to simplify the batching logic and more importantly to make
|
| 238 |
+
# automatic adaptive batches much much easier to implement
|
| 239 |
+
# - any OOMs will happen right away rather than near the end
|
| 240 |
+
|
| 241 |
+
toks = x[1] + x[2]
|
| 242 |
+
return -len(toks), tuple(toks)
|
| 243 |
+
|
| 244 |
+
# TODO: automatic (variable) batch size detection for vectorization
|
| 245 |
+
re_ord = utils.Reorderer(requests, _collate)
|
| 246 |
+
for chunk in utils.chunks(
|
| 247 |
+
tqdm(re_ord.get_reordered(), disable=disable_tqdm), self.batch_size
|
| 248 |
+
):
|
| 249 |
+
inps = []
|
| 250 |
+
cont_toks_list = []
|
| 251 |
+
inplens = []
|
| 252 |
+
|
| 253 |
+
padding_length = None
|
| 254 |
+
|
| 255 |
+
# because vectorizing is annoying, we first convert each (context, continuation) pair to padded
|
| 256 |
+
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
|
| 257 |
+
# again because vectorizing is annoying
|
| 258 |
+
|
| 259 |
+
for _, context_enc, continuation_enc in chunk:
|
| 260 |
+
# sanity check
|
| 261 |
+
assert len(context_enc) > 0
|
| 262 |
+
assert len(continuation_enc) > 0
|
| 263 |
+
assert len(continuation_enc) <= self.max_length
|
| 264 |
+
|
| 265 |
+
# how this all works:
|
| 266 |
+
# CTX CONT
|
| 267 |
+
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
|
| 268 |
+
# gpt2 \ \
|
| 269 |
+
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
|
| 270 |
+
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice
|
| 271 |
+
|
| 272 |
+
# when too long to fit in context, truncate from the left
|
| 273 |
+
inp = torch.tensor(
|
| 274 |
+
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
|
| 275 |
+
dtype=torch.long,
|
| 276 |
+
).to(self.device)
|
| 277 |
+
(inplen,) = inp.shape
|
| 278 |
+
|
| 279 |
+
cont = continuation_enc
|
| 280 |
+
|
| 281 |
+
# since in _collate we make sure length is descending, the longest is always the first one.
|
| 282 |
+
padding_length = (
|
| 283 |
+
padding_length if padding_length is not None else inplen
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# pad length from seq to padding_length
|
| 287 |
+
inp = torch.cat(
|
| 288 |
+
[
|
| 289 |
+
inp, # [seq]
|
| 290 |
+
torch.zeros(padding_length - inplen, dtype=torch.long).to(
|
| 291 |
+
inp.device
|
| 292 |
+
), # [padding_length - seq]
|
| 293 |
+
],
|
| 294 |
+
dim=0,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
inps.append(inp.unsqueeze(0)) # [1, padding_length]
|
| 298 |
+
cont_toks_list.append(cont)
|
| 299 |
+
inplens.append(inplen)
|
| 300 |
+
|
| 301 |
+
batched_inps = torch.cat(inps, dim=0) # [batch, padding_length
|
| 302 |
+
multi_logits = F.log_softmax(
|
| 303 |
+
self._model_call(batched_inps), dim=-1
|
| 304 |
+
).cpu() # [batch, padding_length, vocab]
|
| 305 |
+
|
| 306 |
+
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(
|
| 307 |
+
chunk, multi_logits, inps, inplens, cont_toks_list
|
| 308 |
+
):
|
| 309 |
+
# Slice to original seq length
|
| 310 |
+
contlen = len(cont_toks)
|
| 311 |
+
logits = logits[inplen - contlen : inplen].unsqueeze(
|
| 312 |
+
0
|
| 313 |
+
) # [1, seq, vocab]
|
| 314 |
+
|
| 315 |
+
# Check if per-token argmax is exactly equal to continuation
|
| 316 |
+
greedy_tokens = logits.argmax(dim=-1)
|
| 317 |
+
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(
|
| 318 |
+
0
|
| 319 |
+
) # [1, seq]
|
| 320 |
+
max_equal = (greedy_tokens == cont_toks).all()
|
| 321 |
+
|
| 322 |
+
# Obtain log-probs at the corresponding continuation token indices
|
| 323 |
+
# last_token_slice = logits[:, -1, :].squeeze(0).tolist()
|
| 324 |
+
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
|
| 325 |
+
-1
|
| 326 |
+
) # [1, seq]
|
| 327 |
+
|
| 328 |
+
# Answer: (log prob, is-exact-match)
|
| 329 |
+
answer = (float(logits.sum()), bool(max_equal))
|
| 330 |
+
|
| 331 |
+
# partial caching
|
| 332 |
+
if cache_key is not None:
|
| 333 |
+
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
|
| 334 |
+
|
| 335 |
+
res.append(answer)
|
| 336 |
+
|
| 337 |
+
return re_ord.get_original(res)
|
| 338 |
+
|
| 339 |
+
def greedy_until(self, requests):
|
| 340 |
+
# TODO: implement fully general `until` that handles until that are
|
| 341 |
+
# multiple tokens or that span multiple tokens correctly
|
| 342 |
+
|
| 343 |
+
# TODO: extract to TokenizedLM?
|
| 344 |
+
res = []
|
| 345 |
+
|
| 346 |
+
def _collate(x):
|
| 347 |
+
toks = self.tok_encode(x[0])
|
| 348 |
+
return len(toks), x[0]
|
| 349 |
+
|
| 350 |
+
re_ord = utils.Reorderer(requests, _collate)
|
| 351 |
+
for req in tqdm(re_ord.get_reordered()):
|
| 352 |
+
if len(req) == 2:
|
| 353 |
+
context, until = req
|
| 354 |
+
max_gen_toks = self.max_gen_toks
|
| 355 |
+
elif len(req) == 3:
|
| 356 |
+
context, until, max_num_tokens = req
|
| 357 |
+
max_gen_toks = max_num_tokens
|
| 358 |
+
else:
|
| 359 |
+
raise NotImplementedError
|
| 360 |
+
if isinstance(until, str):
|
| 361 |
+
until = [until]
|
| 362 |
+
# (primary_until,) = self.tok_encode(until[0])
|
| 363 |
+
primary_until = self.tok_encode(until[0])
|
| 364 |
+
if len(primary_until) == 0:
|
| 365 |
+
primary_until = self.tokenizer.eos_token_id
|
| 366 |
+
else:
|
| 367 |
+
primary_until = primary_until[-1]
|
| 368 |
+
context_enc = torch.tensor(
|
| 369 |
+
[self.tok_encode(context)[max_gen_toks - self.max_length :]]
|
| 370 |
+
).to(self.device)
|
| 371 |
+
|
| 372 |
+
cont = self._model_generate(
|
| 373 |
+
context_enc, context_enc.shape[1] + max_gen_toks, primary_until
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
|
| 377 |
+
|
| 378 |
+
for term in until:
|
| 379 |
+
s = s.split(term)[0]
|
| 380 |
+
|
| 381 |
+
# partial caching
|
| 382 |
+
self.cache_hook.add_partial("greedy_until", (context, until), s)
|
| 383 |
+
|
| 384 |
+
res.append(s)
|
| 385 |
+
|
| 386 |
+
return re_ord.get_original(res)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class Task(abc.ABC):
|
| 390 |
+
"""A task represents an entire benchmark including its dataset, problems,
|
| 391 |
+
answers, and evaluation methods. See BoolQ for a simple example implementation
|
| 392 |
+
|
| 393 |
+
A `doc` can be any python object which represents one instance of evaluation.
|
| 394 |
+
This is usually a dictionary e.g.
|
| 395 |
+
{"question": ..., "answer": ...} or
|
| 396 |
+
{"question": ..., question, answer)
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
|
| 400 |
+
# or a path to a custom `datasets` loading script.
|
| 401 |
+
DATASET_PATH: str = None
|
| 402 |
+
|
| 403 |
+
# The name of a subset within `DATASET_PATH`.
|
| 404 |
+
DATASET_NAME: str = None
|
| 405 |
+
# Load tokenizer inside Task class
|
| 406 |
+
LOAD_TOKENIZER: bool = False
|
| 407 |
+
|
| 408 |
+
def __init__(self, data_dir=None, cache_dir=None, download_mode=None):
|
| 409 |
+
"""
|
| 410 |
+
:param data_dir: str
|
| 411 |
+
Stores the path to a local folder containing the `Task`'s data files.
|
| 412 |
+
Use this to specify the path to manually downloaded data (usually when
|
| 413 |
+
the dataset is not publicly accessible).
|
| 414 |
+
:param cache_dir: str
|
| 415 |
+
The directory to read/write the `Task` dataset. This follows the
|
| 416 |
+
HuggingFace `datasets` API with the default cache directory located at:
|
| 417 |
+
`~/.cache/huggingface/datasets`
|
| 418 |
+
NOTE: You can change the cache location globally for a given process
|
| 419 |
+
by setting the shell environment variable, `HF_DATASETS_CACHE`,
|
| 420 |
+
to another directory:
|
| 421 |
+
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
| 422 |
+
:param download_mode: datasets.DownloadMode
|
| 423 |
+
How to treat pre-existing `Task` downloads and data.
|
| 424 |
+
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
| 425 |
+
Reuse download and reuse dataset.
|
| 426 |
+
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
| 427 |
+
Reuse download with fresh dataset.
|
| 428 |
+
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
| 429 |
+
Fresh download and fresh dataset.
|
| 430 |
+
"""
|
| 431 |
+
self.download(data_dir, cache_dir, download_mode)
|
| 432 |
+
self._training_docs = None
|
| 433 |
+
self._fewshot_docs = None
|
| 434 |
+
self._target_to_docs = None
|
| 435 |
+
self._target_to_ratio = None
|
| 436 |
+
|
| 437 |
+
def download(self, data_dir=None, cache_dir=None, download_mode=None):
|
| 438 |
+
"""Downloads and returns the task dataset.
|
| 439 |
+
Override this method to download the dataset from a custom API.
|
| 440 |
+
|
| 441 |
+
:param data_dir: str
|
| 442 |
+
Stores the path to a local folder containing the `Task`'s data files.
|
| 443 |
+
Use this to specify the path to manually downloaded data (usually when
|
| 444 |
+
the dataset is not publicly accessible).
|
| 445 |
+
:param cache_dir: str
|
| 446 |
+
The directory to read/write the `Task` dataset. This follows the
|
| 447 |
+
HuggingFace `datasets` API with the default cache directory located at:
|
| 448 |
+
`~/.cache/huggingface/datasets`
|
| 449 |
+
NOTE: You can change the cache location globally for a given process
|
| 450 |
+
by setting the shell environment variable, `HF_DATASETS_CACHE`,
|
| 451 |
+
to another directory:
|
| 452 |
+
`export HF_DATASETS_CACHE="/path/to/another/directory"`
|
| 453 |
+
:param download_mode: datasets.DownloadMode
|
| 454 |
+
How to treat pre-existing `Task` downloads and data.
|
| 455 |
+
- `datasets.DownloadMode.REUSE_DATASET_IF_EXISTS`
|
| 456 |
+
Reuse download and reuse dataset.
|
| 457 |
+
- `datasets.DownloadMode.REUSE_CACHE_IF_EXISTS`
|
| 458 |
+
Reuse download with fresh dataset.
|
| 459 |
+
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
|
| 460 |
+
Fresh download and fresh dataset.
|
| 461 |
+
"""
|
| 462 |
+
self.dataset = datasets.load_dataset(
|
| 463 |
+
path=self.DATASET_PATH,
|
| 464 |
+
name=self.DATASET_NAME,
|
| 465 |
+
data_dir=data_dir,
|
| 466 |
+
cache_dir=cache_dir,
|
| 467 |
+
download_mode=download_mode,
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
def should_decontaminate(self):
|
| 471 |
+
"""Whether this task supports decontamination against model training set."""
|
| 472 |
+
return False
|
| 473 |
+
|
| 474 |
+
@abstractmethod
|
| 475 |
+
def has_training_docs(self):
|
| 476 |
+
"""Whether the task has a training set"""
|
| 477 |
+
pass
|
| 478 |
+
|
| 479 |
+
@abstractmethod
|
| 480 |
+
def has_validation_docs(self):
|
| 481 |
+
"""Whether the task has a validation set"""
|
| 482 |
+
pass
|
| 483 |
+
|
| 484 |
+
@abstractmethod
|
| 485 |
+
def has_test_docs(self):
|
| 486 |
+
"""Whether the task has a test set"""
|
| 487 |
+
pass
|
| 488 |
+
|
| 489 |
+
def training_docs(self):
|
| 490 |
+
"""
|
| 491 |
+
:return: Iterable[obj]
|
| 492 |
+
A iterable of any object, that doc_to_text can handle
|
| 493 |
+
"""
|
| 494 |
+
return []
|
| 495 |
+
|
| 496 |
+
def validation_docs(self):
|
| 497 |
+
"""
|
| 498 |
+
:return: Iterable[obj]
|
| 499 |
+
A iterable of any object, that doc_to_text can handle
|
| 500 |
+
"""
|
| 501 |
+
return []
|
| 502 |
+
|
| 503 |
+
def test_docs(self):
|
| 504 |
+
"""
|
| 505 |
+
:return: Iterable[obj]
|
| 506 |
+
A iterable of any object, that doc_to_text can handle
|
| 507 |
+
"""
|
| 508 |
+
return []
|
| 509 |
+
|
| 510 |
+
def _process_doc(self, doc):
|
| 511 |
+
"""
|
| 512 |
+
Override this to process (detokenize, strip, replace, etc.) individual
|
| 513 |
+
documents. This can be used in a map over documents of a data split.
|
| 514 |
+
E.g. `map(self._process_doc, self.dataset["validation"])`
|
| 515 |
+
|
| 516 |
+
:return: dict
|
| 517 |
+
The processed version of the specified `doc`.
|
| 518 |
+
"""
|
| 519 |
+
return doc
|
| 520 |
+
|
| 521 |
+
def fewshot_examples(self, k, rnd, stratified=False):
|
| 522 |
+
"""Returns few shot examples from training docs"""
|
| 523 |
+
if self._training_docs is None:
|
| 524 |
+
self._training_docs = list(self.training_docs())
|
| 525 |
+
|
| 526 |
+
if stratified:
|
| 527 |
+
return self._stratified_fewshot_examples(self._training_docs, k, rnd)
|
| 528 |
+
else:
|
| 529 |
+
return rnd.sample(self._training_docs, k)
|
| 530 |
+
|
| 531 |
+
    def _stratified_fewshot_examples(self, docs, k, rnd):
        """Returns few shot examples from `docs` with stratified sampling,
        using the target from `self.doc_to_target` as the stratum.

        WARNING: in order to speed up computation, this method caches the following
        based on `docs`:
        - `self._target_to_docs`, which stores a mapping from target to docs, and
        - `self._target_to_ratio`, which stores a mapping from target to the ratio of docs
        Thus, `docs` MUST be constant across different method calls.
        This assumption should generally hold true, since for a given task `docs`
        will typically be either one of:
        - `self._training_docs` if the dataset for the task has training data, or
        - `self._fewshot_docs` if the dataset for the task does not have any training data
        """
        # Build the per-target index and ratios once; reused on subsequent calls.
        if self._target_to_docs is None or self._target_to_ratio is None:
            self._target_to_docs = defaultdict(list)
            for doc in docs:
                target = self.doc_to_target(doc)
                self._target_to_docs[target].append(doc)

            self._target_to_ratio = {
                target: len(_docs) / len(docs)
                for target, _docs in self._target_to_docs.items()
            }

        # `k` should generally be constant across different method calls
        # (as the number of few-shot is typically fixed for a given task),
        # but this may not be guaranteed, so calculate the number of sample
        # for each target per method call
        target_to_num_samples = {
            target: int(ratio * k) for target, ratio in self._target_to_ratio.items()
        }
        # Handle any rounding discrepancies by adjusting the counts
        # (int() truncates, so the per-target counts can sum to less than k).
        remaining_samples = k - sum(target_to_num_samples.values())
        if remaining_samples > 0:
            for _ in range(remaining_samples):
                # Increment the min value
                target = min(target_to_num_samples, key=target_to_num_samples.get)
                target_to_num_samples[target] += 1

        samples = []
        for target, num_samples in target_to_num_samples.items():
            # NOTE(review): assumes each stratum holds at least `num_samples`
            # docs — rnd.sample raises ValueError otherwise.
            samples.extend(rnd.sample(self._target_to_docs[target], num_samples))
        # Randomly shuffle the samples to prevent potential biases
        # that may arise from a fixed ordering of the targets
        rnd.shuffle(samples)
        return samples
|
| 578 |
+
|
| 579 |
+
def doc_to_decontamination_query(self, doc):
|
| 580 |
+
print(
|
| 581 |
+
"Override doc_to_decontamination_query with document specific decontamination query."
|
| 582 |
+
)
|
| 583 |
+
assert False
|
| 584 |
+
|
| 585 |
+
    @abstractmethod
    def doc_to_text(self, doc):
        """Convert a document into the input text (prompt) presented to the LM.

        :param doc:
            A document from one of the task's splits.
        :return: str
        """
        pass
|
| 588 |
+
|
| 589 |
+
    @abstractmethod
    def doc_to_target(self, doc):
        """Convert a document into the gold target string (concatenated after
        the text from `doc_to_text` when building few-shot examples).

        :param doc:
            A document from one of the task's splits.
        :return: str
        """
        pass
|
| 592 |
+
|
| 593 |
+
    @abstractmethod
    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        pass
|
| 606 |
+
|
| 607 |
+
    @abstractmethod
    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        pass
|
| 619 |
+
|
| 620 |
+
    @abstractmethod
    def aggregation(self):
        """
        :returns: {str: [metric_score] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metric scores into a single float.
        """
        pass
|
| 628 |
+
|
| 629 |
+
    @abstractmethod
    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better.
        """
        pass
|
| 637 |
+
|
| 638 |
+
def fewshot_description(self):
|
| 639 |
+
import warnings
|
| 640 |
+
|
| 641 |
+
warnings.warn(
|
| 642 |
+
"`fewshot_description` will be removed in futures versions. Pass "
|
| 643 |
+
"any custom descriptions to the `evaluate` function instead.",
|
| 644 |
+
DeprecationWarning,
|
| 645 |
+
)
|
| 646 |
+
return ""
|
| 647 |
+
|
| 648 |
+
    @utils.positional_deprecated
    def fewshot_context(
        self,
        doc,
        num_fewshot,
        provide_description=None,
        rnd=None,
        description=None,
        stratified=False,
    ):
        """Returns a fewshot context string that is made up of a prepended description
        (if provided), the `num_fewshot` number of examples, and an appended prompt example.

        :param doc: str
            The document as returned from training_docs, validation_docs, or test_docs.
        :param num_fewshot: int
            The number of fewshot examples to provide in the returned context string.
        :param provide_description: bool
            Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
        :param rnd: random.Random
            The pseudo-random number generator used to randomly sample examples.
            WARNING: This is currently a required arg although it's optionalized with a default `None`.
        :param description: str
            The task's description that will be prepended to the fewshot examples.
        :param stratified: bool
            When true, does stratified sampling, using the target from `self.doc_to_target` as the stratum.
        :returns: str
            The fewshot context.
        """
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`"
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )
        # Separator between few-shot examples: task-level FEWSHOT_SEP override,
        # a doubled SEP if the task defines one, or a blank line by default.
        if hasattr(self, "FEWSHOT_SEP"):
            FEWSHOT_SEP = self.FEWSHOT_SEP
        elif hasattr(self, "SEP"):
            FEWSHOT_SEP = f"{self.SEP}{self.SEP}"
        else:
            FEWSHOT_SEP = "\n\n"

        # An explicit `description` argument wins (with a separator appended);
        # otherwise fall back to the class-level DESCRIPTION used verbatim.
        if description:
            description += FEWSHOT_SEP
        elif hasattr(self, "DESCRIPTION"):
            description = self.DESCRIPTION
        else:
            description = ""

        if num_fewshot == 0:
            labeled_examples = ""
        else:
            # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
            if self.has_training_docs():
                fewshotex = self.fewshot_examples(
                    k=num_fewshot, rnd=rnd, stratified=stratified
                )
            else:
                if self._fewshot_docs is None:
                    self._fewshot_docs = list(
                        self.validation_docs()
                        if self.has_validation_docs()
                        else self.test_docs()
                    )

                # Sample one extra doc so the eval doc can be dropped below
                # if it happens to be drawn.
                if stratified:
                    fewshotex = self._stratified_fewshot_examples(
                        self._fewshot_docs, num_fewshot + 1, rnd=rnd
                    )
                else:
                    fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)

                # get rid of the doc that's the one we're evaluating, if it's in the fewshot
                fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]

            labeled_examples = (
                FEWSHOT_SEP.join(
                    [
                        self.doc_to_text(doc) + self.doc_to_target(doc)
                        for doc in fewshotex
                    ]
                )
                + FEWSHOT_SEP
            )

        example = self.doc_to_text(doc)
        return description + labeled_examples + example
|
| 742 |
+
|
| 743 |
+
    def set_tokenizer(self, tokenizer):
        # Called by the evaluator to share the LM's tokenizer with the task
        # (see evaluator.evaluate, which calls task.set_tokenizer(lm.tokenizer)).
        self.tokenizer = tokenizer
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
class MultipleChoiceTask(Task):
    """Base class for tasks scored by comparing per-choice loglikelihoods.

    Docs must provide `doc["choices"]` (candidate strings) and `doc["gold"]`
    (index of the correct choice).
    """

    def doc_to_target(self, doc):
        # Target is the gold choice text, space-prefixed for tokenization.
        return " " + doc["choices"][doc["gold"]]

    def construct_requests(self, doc, ctx):
        # One loglikelihood request per candidate choice.
        requests = []
        for choice in doc["choices"]:
            requests.append(rf.loglikelihood(ctx, " {}".format(choice))[0])
        return requests

    def process_results(self, doc, results):
        gold_index = doc["gold"]

        acc = 1.0 if np.argmax(results) == gold_index else 0.0
        # Length-normalized accuracy: divide each score by its choice's
        # character length before taking the argmax.
        choice_lengths = np.array([float(len(c)) for c in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / choice_lengths) == gold_index else 0.0

        return {
            "acc": acc,
            "acc_norm": acc_norm,
            "details": {
                "scores": results,
            },
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
        }
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
class BalancedMultipleChoiceTask(MultipleChoiceTask):
    """A task where the choices are the same every time, and accuracy should be
    calculated separately for each class.

    Originally created for marc-ja, which is severely imbalanced, though also
    useful with less weird datasets. Not suitable for datasets where the choices
    change for every question.
    """

    def process_results(self, doc, results):
        gold = doc["gold"]

        # This isn't very clean, but it may be the best we can do since lm ops
        # are submitted as an iterator for batching
        # NOTE: if the last result is a string it is treated as the raw model
        # response and popped off — this mutates the caller's `results` list.
        response = None
        if isinstance(results[-1], str):
            response = results.pop()

        pred = np.argmax(results)
        acc = 1.0 if np.argmax(results) == gold else 0.0
        completion_len = np.array([float(len(i)) for i in doc["choices"]])
        acc_norm = 1.0 if np.argmax(results / completion_len) == gold else 0.0

        # balanced_acc / mcc / macro_f1 emit per-doc tuples; the paired
        # aggregation functions consume the (label, prediction)-style pairs.
        return {
            "acc": acc,
            "acc_norm": acc_norm,
            "balanced_acc": (acc, gold),
            "mcc": (gold, pred),
            "macro_f1": (gold, pred),
            "details": {
                "question": self.doc_to_text(doc),
                "response": response,
                "scores": results,
            },
        }

    def higher_is_better(self):
        return {
            "acc": True,
            "acc_norm": True,
            "balanced_acc": True,
            "mcc": True,
            "macro_f1": True,
        }

    def aggregation(self):
        return {
            "acc": mean,
            "acc_norm": mean,
            "balanced_acc": balanced_mean,
            "mcc": matthews_corrcoef,
            "macro_f1": macro_f1,
        }
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
class PerplexityTask(Task, abc.ABC):
    """Task scored by rolling loglikelihood over whole documents.

    Each doc is the raw text itself (no prompt, no few-shot); metrics are
    word/byte perplexity and bits-per-byte, aggregated from
    (loglikelihood, length) pairs.
    """

    def should_decontaminate(self):
        """Whether this task supports decontamination against model training set."""
        return True

    def has_training_docs(self):
        # Perplexity tasks never draw from a training split.
        return False

    def fewshot_examples(self, k, rnd):
        # Few-shot examples are meaningless for perplexity scoring.
        assert k == 0
        return []

    def fewshot_context(
        self, doc, num_fewshot, provide_description=None, rnd=None, description=None
    ):
        """Always returns an empty context; perplexity is computed on the bare doc."""
        assert (
            num_fewshot == 0
        ), "The number of fewshot examples must be 0 for perplexity tasks."
        assert (
            rnd is not None
        ), "A `random.Random` generator argument must be provided to `rnd`."
        assert not provide_description, (
            "The `provide_description` arg will be removed in future versions. To prepend "
            "a custom description to the context, supply the corresponding string via the "
            "`description` arg."
        )
        if provide_description is not None:
            # nudge people to not specify it at all
            print(
                "WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
            )

        return ""

    def higher_is_better(self):
        # All perplexity-style metrics improve as they decrease.
        return {
            "word_perplexity": False,
            "byte_perplexity": False,
            "bits_per_byte": False,
        }

    def doc_to_decontamination_query(self, doc):
        return doc

    def doc_to_text(self, doc):
        return ""

    def doc_to_target(self, doc):
        return doc

    def construct_requests(self, doc, ctx):
        # Context must be empty (see fewshot_context above).
        assert not ctx
        return rf.loglikelihood_rolling(self.doc_to_target(doc))

    def process_results(self, doc, results):
        (loglikelihood,) = results
        num_words = self.count_words(doc)
        num_bytes = self.count_bytes(doc)
        return {
            "word_perplexity": (loglikelihood, num_words),
            "byte_perplexity": (loglikelihood, num_bytes),
            "bits_per_byte": (loglikelihood, num_bytes),
        }

    def aggregation(self):
        return {
            "word_perplexity": weighted_perplexity,
            "byte_perplexity": weighted_perplexity,
            "bits_per_byte": bits_per_byte,
        }

    @classmethod
    def count_bytes(cls, doc):
        """Length of `doc` in UTF-8 bytes."""
        return len(doc.encode("utf-8"))

    @classmethod
    def count_words(cls, doc):
        """Downstream tasks with custom word boundaries should override this!"""
        return len(re.split(r"\s+", doc))
|
| 921 |
+
|
| 922 |
+
|
| 923 |
+
def hash_args(attr, args):
    """Return a stable SHA-256 hex digest identifying an (attr, args) request.

    :param attr: str
        Name of the LM method being invoked (e.g. "loglikelihood").
    :param args: iterable
        JSON-serializable positional arguments of the request.
    :return: str
        Hex digest suitable as a cache key.
    """
    payload = json.dumps([attr, *args])
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
class CacheHook:
    """Write-through hook letting an LM record partial results into a CachingLM store."""

    def __init__(self, cachinglm):
        # With no backing CachingLM, the hook becomes a silent no-op.
        self.dbdict = cachinglm.dbdict if cachinglm is not None else None

    def add_partial(self, attr, req, res):
        """Store `res` as the cached result of request `req` on LM method `attr`."""
        if self.dbdict is None:
            return
        self.dbdict[hash_args(attr, req)] = res
|
| 941 |
+
|
| 942 |
+
|
| 943 |
+
class CachingLM:
    def __init__(self, lm, cache_db):
        """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

        :param lm: LM
            Underlying LM
        :param cache_db: str
            Path to cache db
        """
        self.lm = lm
        self.cache_db = cache_db
        # Create the parent directory for the cache file if one was given.
        if os.path.dirname(cache_db):
            os.makedirs(os.path.dirname(cache_db), exist_ok=True)
        # NOTE(review): SqliteDict comes from the third-party `sqlitedict` package.
        self.dbdict = SqliteDict(cache_db, autocommit=True)

        # add hook to lm
        lm.set_cache_hook(self.get_cache_hook())

    def __getattr__(self, attr):
        # __getattr__ only fires for attributes not found normally, so every
        # LM method access (e.g. `loglikelihood`) is proxied through this
        # caching closure.
        def fn(requests):
            res = []
            remaining_reqs = []

            # figure out which ones are cached and which ones are new
            for req in requests:
                hsh = hash_args(attr, req)
                if hsh in self.dbdict:
                    ob = self.dbdict[hsh]

                    assert ob is not None

                    res.append(ob)
                else:
                    # Placeholder None marks a slot to be filled by the LM below.
                    res.append(None)
                    remaining_reqs.append(req)

            # actually run the LM on the requests that do not have cached results
            rem_res = getattr(self.lm, attr)(remaining_reqs)

            # stick the new ones back into the list and also cache any of the new ones
            # `resptr` advances to the next None slot, i.e. the original
            # position of this uncached request, preserving request order.
            resptr = 0
            for req, r in zip(remaining_reqs, rem_res):
                while res[resptr] is not None:
                    resptr += 1

                res[resptr] = r

                # caching
                hsh = hash_args(attr, req)
                self.dbdict[hsh] = r
                self.dbdict.commit()

            return res

        return fn

    def get_cache_hook(self):
        return CacheHook(self)
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
# Number of values each request type yields per call; None means the request
# returns a single opaque result and cannot be indexed or iterated.
REQUEST_RETURN_LENGTHS = {
    "loglikelihood": 2,
    "greedy_until": None,
    "loglikelihood_rolling": None,
}


class Request:
    """A single deferred LM call, optionally narrowed to one return value via `index`."""

    def __init__(self, request_type, args, index=None):
        if request_type not in REQUEST_RETURN_LENGTHS:
            raise NotImplementedError(
                "The request type {} is not implemented!".format(request_type)
            )

        self.request_type = request_type
        self.args = args
        self.index = index

    def __iter__(self):
        # Yield one index-narrowed Request per return value.
        count = REQUEST_RETURN_LENGTHS[self.request_type]
        if count is None:
            raise IndexError("This request type does not return multiple arguments!")
        for position in range(count):
            yield Request(self.request_type, self.args, position)

    def __getitem__(self, i):
        if REQUEST_RETURN_LENGTHS[self.request_type] is None:
            raise IndexError("This request type does not return multiple arguments!")
        return Request(self.request_type, self.args, i)

    def __eq__(self, other):
        return (
            self.request_type == other.request_type
            and self.args == other.args
            and self.index == other.index
        )

    def __repr__(self):
        return f"Req_{self.request_type}{self.args}[{self.index}]\n"
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
class RequestFactory:
    """Builds Request objects via attribute access, e.g. `rf.loglikelihood(ctx, cont)`."""

    def __getattr__(self, attr):
        def build(*args):
            return Request(attr, args)

        return build


# Singleton factory used by tasks to construct deferred LM requests.
rf = RequestFactory()
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/evaluator.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import itertools
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import lm_eval.metrics
|
| 6 |
+
import lm_eval.models
|
| 7 |
+
import lm_eval.tasks
|
| 8 |
+
import lm_eval.base
|
| 9 |
+
from lm_eval.utils import positional_deprecated, run_task_tests
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=None,
    num_fewshot=0,
    batch_size=None,
    device=None,
    no_cache=False,
    limit=None,
    bootstrap_iters=100000,
    description_dict=None,
    check_integrity=False,
    decontamination_ngrams_path=None,
    verbose=False,
):

    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int or list of int
        Number of examples in few-shot context
    :param batch_size: int, optional
        Batch size for model
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param no_cache: bool
        Whether or not to cache
    :param limit: int or list of int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param description_dict: dict[str, str]
        Dictionary of custom task descriptions of the form: `task_name: description`
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :return
        Dictionary of results
    """
    random.seed(1234)
    np.random.seed(1234)

    # `tasks=None` avoids a mutable default argument; it is equivalent to the
    # old `tasks=[]` default (both fail the assert below).
    tasks = [] if tasks is None else tasks
    assert tasks != [], "No tasks specified"

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.models.get_model(model).create_from_arg_string(
            model_args, {"batch_size": batch_size, "device": device}
        )
    else:
        assert isinstance(model, lm_eval.base.LM)
        lm = model

    if not no_cache:
        # Build a filesystem-safe cache filename. `model` may be an LM object
        # (and `model_args` None) — previously that crashed on string
        # concatenation; fall back to the class name / "" in that case.
        model_name = model if isinstance(model, str) else type(model).__name__
        sanitized_args = (model_args or "").replace("=", "-").replace(",", "_").replace("/", "-")
        lm = lm_eval.base.CachingLM(
            lm,
            "lm_cache/" + model_name + "_" + sanitized_args + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        description_dict=description_dict,
        decontamination_ngrams_path=decontamination_ngrams_path,
        verbose=verbose,
    )

    # add info about the model and few shot config
    results["config"] = {
        "model": model,
        "model_args": model_args,
        "num_fewshot": num_fewshot,
        "batch_size": batch_size,
        "device": device,
        "no_cache": no_cache,
        "limit": limit,
        "bootstrap_iters": bootstrap_iters,
        "description_dict": description_dict,
    }

    return results
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
decontaminate_suffix = "_decontaminate"
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@positional_deprecated
|
| 118 |
+
def evaluate(
|
| 119 |
+
lm,
|
| 120 |
+
task_dict,
|
| 121 |
+
provide_description=None,
|
| 122 |
+
num_fewshot=0,
|
| 123 |
+
limit=None,
|
| 124 |
+
bootstrap_iters=100000,
|
| 125 |
+
description_dict=None,
|
| 126 |
+
decontamination_ngrams_path=None,
|
| 127 |
+
verbose=False,
|
| 128 |
+
):
|
| 129 |
+
"""Instantiate and evaluate a model on a list of tasks.
|
| 130 |
+
|
| 131 |
+
:param lm: obj
|
| 132 |
+
Language Model
|
| 133 |
+
:param task_dict: dict[str, Task]
|
| 134 |
+
Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
|
| 135 |
+
:param provide_description: bool
|
| 136 |
+
Not implemented, and this option is deprecated and will be removed in a future version in favor of a different description providing method
|
| 137 |
+
:param num_fewshot: int or list of int
|
| 138 |
+
Number of examples in few-shot context
|
| 139 |
+
:param limit: int or list of int, optional
|
| 140 |
+
Limit the number of examples per task (only use this for testing)
|
| 141 |
+
:param bootstrap_iters:
|
| 142 |
+
Number of iterations for bootstrap statistics
|
| 143 |
+
:param description_dict: dict[str, str]
|
| 144 |
+
Dictionary of custom task descriptions of the form: `task_name: description`
|
| 145 |
+
:return
|
| 146 |
+
Dictionary of results
|
| 147 |
+
"""
|
| 148 |
+
# TODO: completely refactor this entire function to not be a huge mess, ideally breaking it down into smaller pieces
|
| 149 |
+
|
| 150 |
+
# TODO: todo: implement proper description-providing system
|
| 151 |
+
assert not provide_description # not implemented.
|
| 152 |
+
if provide_description is not None:
|
| 153 |
+
# nudge people to not specify it at all
|
| 154 |
+
print(
|
| 155 |
+
"WARNING: provide_description is deprecated and will be removed in a future version in favor of description_dict"
|
| 156 |
+
)
|
| 157 |
+
if isinstance(num_fewshot, list):
|
| 158 |
+
assert len(task_dict) == len(
|
| 159 |
+
num_fewshot
|
| 160 |
+
), f"The number of tasks ({len(task_dict)}) must be same as the number of elements in `num_fewshot` ({len(num_fewshot)})"
|
| 161 |
+
else:
|
| 162 |
+
# num_fewshot is int
|
| 163 |
+
num_fewshot = [num_fewshot] * len(task_dict)
|
| 164 |
+
if isinstance(limit, list):
|
| 165 |
+
assert len(task_dict) == len(
|
| 166 |
+
limit
|
| 167 |
+
), f"The number of tasks ({len(task_dict)}) must be same as the number of elements in `num_fewshot` ({len(limit)})"
|
| 168 |
+
else:
|
| 169 |
+
# limit is int or None
|
| 170 |
+
limit = [limit] * len(task_dict)
|
| 171 |
+
|
| 172 |
+
decontaminate = decontamination_ngrams_path is not None
|
| 173 |
+
|
| 174 |
+
task_dict_items = [
|
| 175 |
+
(name, task)
|
| 176 |
+
for name, task in task_dict.items()
|
| 177 |
+
if (task.has_validation_docs() or task.has_test_docs())
|
| 178 |
+
]
|
| 179 |
+
|
| 180 |
+
results = collections.defaultdict(dict)
|
| 181 |
+
versions = collections.defaultdict(dict)
|
| 182 |
+
|
| 183 |
+
requests = collections.defaultdict(list)
|
| 184 |
+
requests_origin = collections.defaultdict(list)
|
| 185 |
+
|
| 186 |
+
overlaps = collections.defaultdict(list) # {task_name: contaminated_docs}
|
| 187 |
+
|
| 188 |
+
# If we ever run into issues where the eval tasks don't fit in memory and we can't afford a machine with bigger
|
| 189 |
+
# memory, we can always modify this plumbing to support that, but I didn't want to include it just yet because
|
| 190 |
+
# over-engineering is bad (or we could make it write the requests to disk and then read them back out again
|
| 191 |
+
# - probably using an sqlite db because of all the moving parts we have
|
| 192 |
+
|
| 193 |
+
# TODO: we need unit tests & sanity checks or something to ensure that the return of `validation_docs` is stable
|
| 194 |
+
docs = {}
|
| 195 |
+
|
| 196 |
+
docs_for_decontamination = collections.defaultdict(list)
|
| 197 |
+
|
| 198 |
+
# get lists of each type of request
|
| 199 |
+
for idx, (task_name, task) in enumerate(task_dict_items):
|
| 200 |
+
versions[task_name] = task.VERSION
|
| 201 |
+
# default to test doc, fall back to val doc if validation unavailable
|
| 202 |
+
# TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
|
| 203 |
+
if task.has_test_docs():
|
| 204 |
+
task_doc_func = task.test_docs
|
| 205 |
+
task_set = "test" # Required for caching in the decontamination
|
| 206 |
+
elif task.has_validation_docs():
|
| 207 |
+
task_set = "val" # Required for caching in the decontamination
|
| 208 |
+
task_doc_func = task.validation_docs
|
| 209 |
+
else:
|
| 210 |
+
raise RuntimeError("Task has neither test_docs nor validation_docs")
|
| 211 |
+
|
| 212 |
+
# deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
|
| 213 |
+
task_docs = list(task_doc_func())
|
| 214 |
+
rnd = random.Random()
|
| 215 |
+
rnd.seed(42)
|
| 216 |
+
rnd.shuffle(task_docs)
|
| 217 |
+
|
| 218 |
+
description = (
|
| 219 |
+
description_dict[task_name]
|
| 220 |
+
if description_dict and task_name in description_dict
|
| 221 |
+
else ""
|
| 222 |
+
)
|
| 223 |
+
# set tokenizer inside task
|
| 224 |
+
if task.LOAD_TOKENIZER:
|
| 225 |
+
if isinstance(lm, lm_eval.base.CachingLM):
|
| 226 |
+
task.set_tokenizer(lm.lm.tokenizer)
|
| 227 |
+
else:
|
| 228 |
+
task.set_tokenizer(lm.tokenizer)
|
| 229 |
+
# set max_length to task object
|
| 230 |
+
task.max_length = (
|
| 231 |
+
lm.lm.max_length
|
| 232 |
+
if isinstance(lm, lm_eval.base.CachingLM)
|
| 233 |
+
else lm.max_length
|
| 234 |
+
)
|
| 235 |
+
task.max_gen_toks = (
|
| 236 |
+
lm.lm.max_gen_toks
|
| 237 |
+
if isinstance(lm, lm_eval.base.CachingLM)
|
| 238 |
+
else lm.max_gen_toks
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
limit_local = limit[idx]
|
| 242 |
+
if isinstance(limit_local, float):
|
| 243 |
+
limit_local = int(limit_local * len(task_docs))
|
| 244 |
+
print(
|
| 245 |
+
f"Use {limit_local}/{len(task_docs)} samples corresponding to the ratio of {limit[idx]}"
|
| 246 |
+
)
|
| 247 |
+
for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit_local)):
|
| 248 |
+
|
| 249 |
+
if decontaminate and task.should_decontaminate():
|
| 250 |
+
docs_for_decontamination[(task_name, task_set)].append(
|
| 251 |
+
task.doc_to_decontamination_query(doc)
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
docs[(task_name, doc_id)] = doc
|
| 255 |
+
ctx = task.fewshot_context(
|
| 256 |
+
doc=doc, num_fewshot=num_fewshot[idx], rnd=rnd, description=description
|
| 257 |
+
)
|
| 258 |
+
reqs = task.construct_requests(doc, ctx)
|
| 259 |
+
if not isinstance(reqs, (list, tuple)):
|
| 260 |
+
reqs = [reqs]
|
| 261 |
+
for i, req in enumerate(reqs):
|
| 262 |
+
requests[req.request_type].append(req)
|
| 263 |
+
# i: index in requests for a single task instance
|
| 264 |
+
# doc_id: unique id that we can get back to a doc using `docs`
|
| 265 |
+
requests_origin[req.request_type].append((i, task_name, doc, doc_id))
|
| 266 |
+
|
| 267 |
+
# Compare all tasks/sets at once to ensure a single training set scan
|
| 268 |
+
if decontaminate:
|
| 269 |
+
from lm_eval.decontamination.decontaminate import get_train_overlap
|
| 270 |
+
|
| 271 |
+
print("Finding train/test overlap, please wait...")
|
| 272 |
+
overlaps = get_train_overlap(
|
| 273 |
+
docs_for_decontamination, decontamination_ngrams_path, limit
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
# all responses for each (task, doc)
|
| 277 |
+
process_res_queue = collections.defaultdict(list)
|
| 278 |
+
|
| 279 |
+
# execute each type of request
|
| 280 |
+
for reqtype, reqs in requests.items():
|
| 281 |
+
# TODO: right now, this code runs multiple separate LM requests for multiple Requests differing
|
| 282 |
+
# only in index. We could implement some kind of caching, but that would be more of a band-aid
|
| 283 |
+
# solution. we could also implement some kind of auto-grouping here;
|
| 284 |
+
# they should end up next to each other.
|
| 285 |
+
|
| 286 |
+
print("Running", reqtype, "requests")
|
| 287 |
+
resps = getattr(lm, reqtype)([req.args for req in reqs])
|
| 288 |
+
resps = [
|
| 289 |
+
x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
|
| 290 |
+
]
|
| 291 |
+
|
| 292 |
+
for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
|
| 293 |
+
process_res_queue[(task_name, doc_id)].append((i, resp))
|
| 294 |
+
|
| 295 |
+
vals = collections.defaultdict(list)
|
| 296 |
+
# holds detailed responses for error analysis
|
| 297 |
+
details = collections.defaultdict(list)
|
| 298 |
+
|
| 299 |
+
# unpack results and sort back in order and return control to Task
|
| 300 |
+
for (task_name, doc_id), requests in process_res_queue.items():
|
| 301 |
+
requests.sort(key=lambda x: x[0])
|
| 302 |
+
requests = [x[1] for x in requests]
|
| 303 |
+
|
| 304 |
+
task = task_dict[task_name]
|
| 305 |
+
doc = docs[(task_name, doc_id)]
|
| 306 |
+
|
| 307 |
+
metrics = task.process_results(doc, requests)
|
| 308 |
+
if "details" in metrics:
|
| 309 |
+
details[task_name].append(metrics["details"])
|
| 310 |
+
del metrics["details"]
|
| 311 |
+
for metric, value in metrics.items():
|
| 312 |
+
vals[(task_name, metric)].append(value)
|
| 313 |
+
|
| 314 |
+
# Re-use the evaluation for the decontaminated set by just ignoring the overlaps
|
| 315 |
+
if decontaminate and task_name in overlaps:
|
| 316 |
+
if doc_id not in overlaps[task_name]:
|
| 317 |
+
vals[(task_name, metric + decontaminate_suffix)].append(value)
|
| 318 |
+
|
| 319 |
+
# aggregate results
|
| 320 |
+
for (task_name, metric), items in vals.items():
|
| 321 |
+
task = task_dict[task_name]
|
| 322 |
+
real_metric = metric # key when looking up the metric with task.aggregation
|
| 323 |
+
if metric.endswith(decontaminate_suffix):
|
| 324 |
+
real_metric = metric.replace(
|
| 325 |
+
decontaminate_suffix, ""
|
| 326 |
+
) # decontaminated still uses the same metric
|
| 327 |
+
results[task_name][metric] = task.aggregation()[real_metric](items)
|
| 328 |
+
|
| 329 |
+
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
|
| 330 |
+
# so we run them less iterations. still looking for a cleaner way to do this
|
| 331 |
+
|
| 332 |
+
stderr = lm_eval.metrics.stderr_for_metric(
|
| 333 |
+
metric=task.aggregation()[real_metric],
|
| 334 |
+
bootstrap_iters=min(bootstrap_iters, 1000)
|
| 335 |
+
if metric in ["bleu", "chrf", "ter"]
|
| 336 |
+
else bootstrap_iters,
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
if stderr is not None:
|
| 340 |
+
results[task_name][metric + "_stderr"] = stderr(items)
|
| 341 |
+
|
| 342 |
+
if verbose and task_name in details:
|
| 343 |
+
results[task_name]["details"] = details[task_name]
|
| 344 |
+
|
| 345 |
+
return {"results": dict(results), "versions": dict(versions)}
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def make_table(result_dict):
    """Render ``result_dict`` ({"results": ..., "versions": ...}) as a Markdown table."""
    from pytablewriter import MarkdownTableWriter, LatexTableWriter

    headers = ["Task", "Version", "Metric", "Value", "", "Stderr"]
    md_writer = MarkdownTableWriter()
    latex_writer = LatexTableWriter()
    md_writer.headers = headers
    latex_writer.headers = headers

    rows = []
    for task_label, metrics in result_dict["results"].items():
        version = result_dict["versions"][task_label]
        for metric_name, value in metrics.items():
            # "details" entries and stderr companions are not rows themselves.
            if metric_name == "details":
                continue
            if metric_name.endswith("_stderr"):
                continue

            stderr_key = metric_name + "_stderr"
            if stderr_key in metrics:
                rows.append(
                    [task_label, version, metric_name,
                     "%.4f" % value, "±", "%.4f" % metrics[stderr_key]]
                )
            else:
                rows.append([task_label, version, metric_name, "%.4f" % value, "", ""])
            # Only the first emitted row of a task shows its name and version.
            task_label = ""
            version = ""
    md_writer.value_matrix = rows
    latex_writer.value_matrix = rows

    # todo: make latex table look good
    # print(latex_writer.dumps())

    return md_writer.dumps()
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/README.md
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Metric Card For Japanese SQuAD
|
| 2 |
+
Heavily based on https://github.com/huggingface/datasets/tree/main/metrics/squad.
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/__init__.py
ADDED
|
File without changes
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (169 Bytes). View file
|
|
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/__pycache__/evaluate.cpython-310.pyc
ADDED
|
Binary file (4.06 kB). View file
|
|
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/__pycache__/jasquad.cpython-310.pyc
ADDED
|
Binary file (4.06 kB). View file
|
|
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/evaluate.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Official evaluation script for v1.1 of the SQuAD dataset. """
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
import string
|
| 7 |
+
import sys
|
| 8 |
+
from collections import Counter
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def remove_punc(tokens):
    """Drop punctuation tokens (Japanese full-width and ASCII) from *tokens*.

    Returns a new list preserving the order of the surviving tokens.
    """
    exclude = (
        "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
    )
    exclude += string.punctuation
    # A set gives O(1) membership per token; the original scanned a list (O(n)).
    excluded = set(exclude)
    return [tok for tok in tokens if tok not in excluded]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def normalize_answer(s):
    """Normalize Japanese answer text for comparison.

    Applies, in order: emoji removal, character normalization via neologdn
    (width/variant unification), and whitespace collapsing. Note: despite the
    original SQuAD wording, no lower-casing or article removal happens here.
    """
    import emoji
    import neologdn

    def white_space_fix(text):
        # Collapse runs of whitespace to single spaces and trim the ends.
        return " ".join(text.split())

    def remove_emoji(text):
        # First drop characters the emoji library recognizes...
        text = "".join(["" if emoji.is_emoji(c) else c for c in text])
        # ...then sweep common emoji codepoint ranges that may remain.
        emoji_pattern = re.compile(
            "["
            "\U0001F600-\U0001F64F"  # emoticons
            "\U0001F300-\U0001F5FF"  # symbols & pictographs
            "\U0001F680-\U0001F6FF"  # transport & map symbols
            "\U0001F1E0-\U0001F1FF"  # flags (iOS)
            "\U00002702-\U000027B0"
            "]+",
            flags=re.UNICODE,
        )
        return emoji_pattern.sub(r"", text)

    return white_space_fix((neologdn.normalize(remove_emoji(s))))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def f1_score(prediction, ground_truth):
    """Token-level F1 between *prediction* and *ground_truth*.

    Both strings are normalized, tokenized with MeCab (fugashi, wakati mode),
    and stripped of punctuation before counting token overlap.
    """
    from fugashi import Tagger

    # NOTE(review): a fresh Tagger is constructed on every call; hoisting it
    # to module level would avoid reloading the dictionary each time.
    tagger = Tagger("-Owakati")
    prediction_tokens = remove_punc(tagger.parse(normalize_answer(prediction)).split())
    ground_truth_tokens = remove_punc(
        tagger.parse(normalize_answer(ground_truth)).split()
    )
    # Multiset intersection: shared tokens counted with multiplicity.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def exact_match_score(prediction, ground_truth):
    """True iff the normalized prediction equals the normalized ground truth."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best score of *metric_fn(prediction, truth)* over all ground truths."""
    return max(metric_fn(prediction, truth) for truth in ground_truths)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def evaluate(dataset, predictions):
    """Score *predictions* against a SQuAD-v1.1-layout *dataset*.

    Returns percentage ``exact_match`` and ``f1`` over all questions.
    Questions missing from *predictions* score 0 and warn on stderr.
    """
    exact_match = 0
    f1 = 0
    total = 0
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                qid = qa["id"]
                if qid not in predictions:
                    print(
                        "Unanswered question " + qid + " will receive score 0.",
                        file=sys.stderr,
                    )
                    continue
                ground_truths = [answer["text"] for answer in qa["answers"]]
                prediction = predictions[qid]
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths
                )
                f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

    return {
        "exact_match": 100.0 * exact_match / total,
        "f1": 100.0 * f1 / total,
    }
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if __name__ == "__main__":
    # CLI entry point: score a prediction file against a SQuAD-format dataset
    # and print the JSON metrics to stdout.
    expected_version = "1.1"
    parser = argparse.ArgumentParser(
        description="Evaluation for Japanese SQuAD " + expected_version
    )
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        if dataset_json["version"] != expected_version:
            # Version mismatch is only a warning; evaluation still proceeds.
            print(
                "Evaluation expects v-"
                + expected_version
                + ", but got dataset with v-"
                + dataset_json["version"],
                file=sys.stderr,
            )
        dataset = dataset_json["data"]
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(evaluate(dataset, predictions)))
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/jasquad.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The HuggingFace Datasets Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
""" SQuAD metric. """
|
| 15 |
+
|
| 16 |
+
import datasets
|
| 17 |
+
|
| 18 |
+
from .evaluate import evaluate
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
_CITATION = """\
|
| 22 |
+
@inproceedings{Rajpurkar2016SQuAD10,
|
| 23 |
+
title={SQuAD: 100, 000+ Questions for Machine Comprehension of Text},
|
| 24 |
+
author={Pranav Rajpurkar and Jian Zhang and Konstantin Lopyrev and Percy Liang},
|
| 25 |
+
booktitle={EMNLP},
|
| 26 |
+
year={2016}
|
| 27 |
+
}
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
_DESCRIPTION = """
|
| 31 |
+
This metric wrap the official scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).
|
| 32 |
+
|
| 33 |
+
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by
|
| 34 |
+
crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span,
|
| 35 |
+
from the corresponding reading passage, or the question might be unanswerable.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
_KWARGS_DESCRIPTION = """
|
| 39 |
+
Computes SQuAD scores (F1 and EM).
|
| 40 |
+
Args:
|
| 41 |
+
predictions: List of question-answers dictionaries with the following key-values:
|
| 42 |
+
- 'id': id of the question-answer pair as given in the references (see below)
|
| 43 |
+
- 'prediction_text': the text of the answer
|
| 44 |
+
references: List of question-answers dictionaries with the following key-values:
|
| 45 |
+
- 'id': id of the question-answer pair (see above),
|
| 46 |
+
- 'answers': a Dict in the SQuAD dataset format
|
| 47 |
+
{
|
| 48 |
+
'text': list of possible texts for the answer, as a list of strings
|
| 49 |
+
'answer_start': list of start positions for the answer, as a list of ints
|
| 50 |
+
}
|
| 51 |
+
Note that answer_start values are not taken into account to compute the metric.
|
| 52 |
+
Returns:
|
| 53 |
+
'exact_match': Exact match (the normalized answer exactly match the gold answer)
|
| 54 |
+
'f1': The F-score of predicted tokens versus the gold answer
|
| 55 |
+
Examples:
|
| 56 |
+
|
| 57 |
+
>>> predictions = [{'prediction_text': '1976', 'id': '56e10a3be3433e1400422b22'}]
|
| 58 |
+
>>> references = [{'answers': {'answer_start': [97], 'text': ['1976']}, 'id': '56e10a3be3433e1400422b22'}]
|
| 59 |
+
>>> squad_metric = datasets.load_metric("squad")
|
| 60 |
+
>>> results = squad_metric.compute(predictions=predictions, references=references)
|
| 61 |
+
>>> print(results)
|
| 62 |
+
{'exact_match': 100.0, 'f1': 100.0}
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class JaSquad(datasets.Metric):
    """SQuAD-style F1/exact-match metric for Japanese QA (JSQuAD)."""

    def _info(self):
        # Declares the expected feature schema: predictions carry an id plus
        # text; references carry an id plus SQuAD-format answer sequences.
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": {
                        "id": datasets.Value("string"),
                        "prediction_text": datasets.Value("string"),
                    },
                    "references": {
                        "id": datasets.Value("string"),
                        "answers": datasets.features.Sequence(
                            {
                                "text": datasets.Value("string"),
                                "answer_start": datasets.Value("int32"),
                            }
                        ),
                    },
                }
            ),
            codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
            reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
        )

    def _compute(self, predictions, references):
        """Compute {exact_match, f1}; memoizes the last (predictions, references)."""
        # Map question id -> predicted answer text for the evaluate() script.
        pred_dict = {
            prediction["id"]: prediction["prediction_text"]
            for prediction in predictions
        }
        # Re-wrap references into the nested article/paragraph/qas layout
        # expected by the official SQuAD evaluation script.
        dataset = [
            {
                "paragraphs": [
                    {
                        "qas": [
                            {
                                "answers": [
                                    {"text": answer_text}
                                    for answer_text in ref["answers"]["text"]
                                ],
                                "id": ref["id"],
                            }
                            for ref in references
                        ]
                    }
                ]
            }
        ]
        # Single-entry cache: if the exact same inputs were scored last time,
        # return the cached result instead of re-running evaluation.
        # NOTE(review): `dataset`/`pred_dict` above are built even on a cache
        # hit; checking the cache first would avoid that work.
        score = getattr(self, "cached_s", None)
        if score:
            cached_p = getattr(self, "cached_p", None)
            cached_r = getattr(self, "cached_r", None)
            if cached_p == predictions and cached_r == references:
                return score

        score = evaluate(dataset=dataset, predictions=pred_dict)
        setattr(self, "cached_s", score)
        setattr(self, "cached_p", list(predictions))
        setattr(self, "cached_r", list(references))
        return score
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/jasquad/requirements.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
git+https://github.com/huggingface/evaluate@{COMMIT_PLACEHOLDER}
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/metrics.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from collections.abc import Iterable
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import sacrebleu
|
| 6 |
+
import sklearn.metrics
|
| 7 |
+
import random
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def mean(arr):
    """Arithmetic mean of a non-empty sequence."""
    total = sum(arr)
    return total / len(arr)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def pop_stddev(arr):
    """Population standard deviation (normalized by N)."""
    mu = sum(arr) / len(arr)
    return math.sqrt(sum((x - mu) ** 2 for x in arr) / len(arr))


def sample_stddev(arr):
    """Sample standard deviation (Bessel-corrected, normalized by N - 1)."""
    mu = sum(arr) / len(arr)
    return math.sqrt(sum((x - mu) ** 2 for x in arr) / (len(arr) - 1))


def mean_stderr(arr):
    """Standard error of the mean: sample stddev over sqrt(N)."""
    return sample_stddev(arr) / math.sqrt(len(arr))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def median(arr):
    # NOTE(review): this returns the middle element of *arr* as passed in; the
    # list is NOT sorted here, so it is only a true median when callers supply
    # already-sorted data — confirm against call sites before "fixing".
    return arr[len(arr) // 2]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def balanced_mean(arr):
    """Macro-averaged accuracy: mean of per-class mean accuracies.

    Each item of *arr* is an ``(accuracy_score, class_label)`` pair.
    """
    by_class = defaultdict(list)
    for score, label in arr:
        by_class[label].append(score)

    class_means = [sum(scores) / len(scores) for scores in by_class.values()]
    return sum(class_means) / len(class_means)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def matthews_corrcoef(items):
    """Matthews correlation coefficient over (gold, pred) pairs."""
    golds, preds = zip(*items)
    return sklearn.metrics.matthews_corrcoef(golds, preds)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def f1_score(items):
    """Binary F1 over (gold, pred) pairs."""
    golds, preds = zip(*items)
    fscore = sklearn.metrics.f1_score(golds, preds)
    # np.max of the scalar fscore is a no-op, kept for parity with the original.
    return np.max(fscore)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def macro_f1(items):
    """Macro-averaged F1 over (gold, pred) pairs.

    Unlike :func:`f1_score`, this uses ``average="macro"`` rather than the
    default binary averaging.
    """
    golds, preds = zip(*items)
    return sklearn.metrics.f1_score(golds, preds, average="macro")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def acc_all(items):
    """Fraction of questions whose answers are ALL labeled correctly.

    Each item is a ``(pred, doc)`` pair; answers are grouped by the
    (paragraph, question) id pair found in ``doc["idx"]``.
    """
    per_question = defaultdict(list)
    for pred, doc in items:
        key = (doc["idx"]["paragraph"], doc["idx"]["question"])
        gold_label = doc["label"] == 1
        per_question[key].append(gold_label == pred)

    return np.mean([int(all(flags)) for flags in per_question.values()])
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def acc_all_stderr(items):
    """Standard error of the all-answers-correct accuracy.

    Answers are grouped by question id only (unlike :func:`acc_all`, which
    also keys on the paragraph id).
    """
    per_question = defaultdict(list)
    for pred, doc in items:
        qid = doc["idx"]["question"]
        gold_label = doc["label"] == 1
        per_question[qid].append(gold_label == pred)

    return mean_stderr([int(all(flags)) for flags in per_question.values()])
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Compute max metric between prediction and each ground truth."""
    return max(metric_fn(prediction, truth) for truth in ground_truths)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def perplexity(items):
    """Perplexity from a list of log-likelihoods: exp(-average loglikelihood)."""
    avg_loglikelihood = sum(items) / len(items)
    return math.exp(-avg_loglikelihood)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def weighted_mean(items):
    """Sum of numerators over sum of denominators for (num, denom) pairs."""
    numerators, denominators = zip(*items)
    return sum(numerators) / sum(denominators)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def weighted_perplexity(items):
    """Perplexity over (loglikelihood, weight) pairs, weight-normalized."""
    logliks, weights = zip(*items)
    return math.exp(-(sum(logliks) / sum(weights)))


def bits_per_byte(items):
    """Bits-per-byte over (loglikelihood, byte_count) pairs."""
    logliks, weights = zip(*items)
    return -(sum(logliks) / sum(weights)) / math.log(2)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def bleu(items):
    """Corpus BLEU over (refs, pred) pairs via sacrebleu.

    BLEU counts matching n-grams between each candidate and its references,
    regardless of word order.
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs, preds = zip(*items)
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


def chrf(items):
    """Corpus chrF++ over (refs, pred) pairs via sacrebleu.

    chrF++ scores machine translation output from character n-gram precision
    and recall, enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better # TODO I think
    """
    refs, preds = zip(*items)
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


def ter(items):
    """Corpus Translation Error Rate over (refs, pred) pairs via sacrebleu.

    TER measures the number of edits needed to turn a system output into one
    of the references.
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs, preds = zip(*items)
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def is_non_str_iterable(obj):
    """True for iterables other than plain strings."""
    if isinstance(obj, str):
        return False
    return isinstance(obj, Iterable)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str])
    # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions that I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not is_non_str_iterable(refs):
        # NOTE(review): this branch looks suspicious — a plain str would be
        # split into characters by list(), and a truly non-iterable refs would
        # raise TypeError here. Confirm the intended input types.
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        # Single reference per prediction: wrap each in its own list.
        refs = [[ref] for ref in refs]
    # Transpose: per-prediction ref lists -> per-reference-stream lists.
    refs = list(zip(*refs))
    # Note the number of refs in each ref list much match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# stderr stuff
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class _bootstrap_internal:
|
| 220 |
+
def __init__(self, f, n):
|
| 221 |
+
self.f = f
|
| 222 |
+
self.n = n
|
| 223 |
+
|
| 224 |
+
def __call__(self, v):
|
| 225 |
+
i, xs = v
|
| 226 |
+
rnd = random.Random()
|
| 227 |
+
rnd.seed(i)
|
| 228 |
+
res = []
|
| 229 |
+
for _ in range(self.n):
|
| 230 |
+
res.append(self.f(rnd.choices(xs, k=len(xs))))
|
| 231 |
+
return res
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def bootstrap_stderr(f, xs, iters):
    """Estimate the standard error of metric *f* over *xs* by bootstrapping.

    Resampling is split into chunks of up to 1000 iterations, each handled by
    a worker process with its own deterministic per-chunk seed.
    """
    import multiprocessing as mp

    pool = mp.Pool(mp.cpu_count())
    # this gives a biased estimate of the stderr (i.e w/ the mean, it gives something
    # equivalent to stderr calculated without Bessel's correction in the stddev.
    # Unfortunately, I haven't been able to figure out what the right correction is
    # to make the bootstrap unbiased - i considered multiplying by sqrt(n/(n-1)) but
    # that would be ad-hoc and I can't prove that that would actually be an unbiased estimator)
    # Thankfully, shouldn't matter because our samples are pretty big usually anyways
    res = []
    chunk_size = min(1000, iters)
    from tqdm import tqdm

    print("bootstrapping for stddev:", f.__name__)
    for bootstrap in tqdm(
        pool.imap(
            _bootstrap_internal(f, chunk_size),
            [(i, xs) for i in range(iters // chunk_size)],
        ),
        total=iters // chunk_size,
    ):
        # sample w replacement
        res.extend(bootstrap)

    pool.close()
    return sample_stddev(res)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def stderr_for_metric(metric, bootstrap_iters):
    """Return a stderr function for *metric*, or None when none applies.

    Metrics without a closed-form stderr are bootstrapped; ``mean`` and
    ``acc_all`` use their analytic counterparts.
    """
    bootstrappable = (
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    )

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    analytic = {mean: mean_stderr, acc_all: acc_all_stderr}
    return analytic.get(metric, None)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def yesno(x):
    """Map truthiness of *x* to the literal strings "yes"/"no"."""
    return "yes" if x else "no"
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/prompts.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def jslm_beta(task):
    """JSLM Beta uses a different prompt for JCommonSenseQA."""
    return "0.2.1" if task == "jcommonsenseqa" else "0.2"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Maps user-facing short prompt names to prompt version codes. A value may
# also be a callable taking the task name and returning a code (see
# get_prompt_code and jslm_beta).
PROMPT_CODES = {
    "user": "0.0",
    "jgpt": "0.1",
    "fintan": "0.2",
    "fintan2": "0.2.1",
    "ja-alpaca": "0.3",
    "rinna-sft": "0.4",
    "rinna-bilingual": "0.5",
    "llama2": "0.6",
    "jslm-beta": jslm_beta,
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_prompt_code(short_name, task=None):
    """Get the prompt code given a short name.

    Usually, this is a simple dictionary lookup. But it can depend on the task
    sometimes, in which case the registered value is a callable taking the
    task name.

    Raises:
        KeyError: if *short_name* is not a registered prompt.
    """
    code = PROMPT_CODES[short_name]

    if callable(code):
        # Bug fix: the original returned ``callable(task)`` — a bool saying
        # whether *task* is callable — instead of invoking the resolver.
        return code(task)
    else:
        return code
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/suites/__init__.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Functionality related to "eval suites". A suite is a collection of tasks with
|
| 2 |
+
# options pre-configured. Different models can be run with the same suite to
|
| 3 |
+
# compare them.
|
| 4 |
+
import configparser
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Optional
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# This file is the path where suite configs go
|
| 11 |
+
SUITE_DIR = Path(os.path.dirname(os.path.realpath(__file__))) / "configs"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class TaskSpec:
    """Specification of a task in an eval suite.

    A suite is a list of these specs, plus a prompt."""

    # The real arguments have to be massaged into messy strings and parallel
    # lists, but this is a more reasonable structure - we can handle conversion
    # separately.

    name: str  # task name, e.g. "jsquad" (the part after "tasks." in the config)
    fewshot: int  # number of few-shot examples to use
    version: Optional[str]  # task version string, or None when unversioned
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_suite(name):
    """Read in configuration for a test suite.

    A suite will have a config file named something like `my_suite.conf`. For
    each task in the file, a version, fewshot config, and any other details
    will be specified.

    Example entry:

        [tasks.mgsm]
        version = 1.0
        fewshot = 5
    """
    parser = configparser.ConfigParser()
    parser.read(SUITE_DIR / (name + ".conf"))

    # Only sections named "tasks.<name>" describe tasks; everything else
    # (including configparser's DEFAULT section) is ignored.
    prefix = "tasks."
    return [
        TaskSpec(
            name=section[len(prefix):],
            version=options.get("version", None),
            fewshot=int(options["fewshot"]),
        )
        for section, options in parser.items()
        if section.startswith(prefix)
    ]
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/suites/configs/ja8.conf
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This is the standard eight-task eval suite.
|
| 2 |
+
|
| 3 |
+
[tasks.mgsm]
|
| 4 |
+
version = 1.0
|
| 5 |
+
fewshot = 5
|
| 6 |
+
|
| 7 |
+
[tasks.xwinograd_ja]
|
| 8 |
+
# this has no version
|
| 9 |
+
fewshot = 0
|
| 10 |
+
|
| 11 |
+
[tasks.xlsum_ja]
|
| 12 |
+
version = 1.0
|
| 13 |
+
fewshot = 1
|
| 14 |
+
|
| 15 |
+
[tasks.jaqket_v2]
|
| 16 |
+
version = 0.2
|
| 17 |
+
fewshot = 1
|
| 18 |
+
|
| 19 |
+
[tasks.marc_ja]
|
| 20 |
+
version = 1.1
|
| 21 |
+
fewshot = 3
|
| 22 |
+
|
| 23 |
+
[tasks.jnli]
|
| 24 |
+
version = 1.3
|
| 25 |
+
fewshot = 3
|
| 26 |
+
|
| 27 |
+
[tasks.jcommonsenseqa]
|
| 28 |
+
version = 1.1
|
| 29 |
+
fewshot = 3
|
| 30 |
+
|
| 31 |
+
[tasks.jsquad]
|
| 32 |
+
version = 1.1
|
| 33 |
+
fewshot = 2
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/coqa.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CoQA: A Conversational Question Answering Challenge
|
| 3 |
+
https://arxiv.org/pdf/1808.07042.pdf
|
| 4 |
+
|
| 5 |
+
CoQA is a large-scale dataset for building Conversational Question Answering
|
| 6 |
+
systems. The goal of the CoQA challenge is to measure the ability of machines to
|
| 7 |
+
understand a text passage and answer a series of interconnected questions that
|
| 8 |
+
appear in a conversation.
|
| 9 |
+
|
| 10 |
+
Homepage: https://stanfordnlp.github.io/coqa/
|
| 11 |
+
"""
|
| 12 |
+
import inspect
|
| 13 |
+
import transformers.data.metrics.squad_metrics as squad_metrics
|
| 14 |
+
import lm_eval.datasets.coqa.coqa
|
| 15 |
+
from lm_eval.base import Task, rf, mean
|
| 16 |
+
from itertools import zip_longest
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
_CITATION = """
|
| 20 |
+
@misc{reddy2018coqa,
|
| 21 |
+
title={CoQA: A Conversational Question Answering Challenge},
|
| 22 |
+
author={Siva Reddy and Danqi Chen and Christopher D. Manning},
|
| 23 |
+
year={2018},
|
| 24 |
+
eprint={1808.07042},
|
| 25 |
+
archivePrefix={arXiv},
|
| 26 |
+
primaryClass={cs.CL}
|
| 27 |
+
}
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CoQA(Task):
    """CoQA conversational question answering.

    Each document is a story plus a sequence of (question, answer) turns.
    The model sees the story and the dialog history and must answer the
    final question; scoring uses SQuAD-style EM/F1 against every valid
    reference answer for that turn.
    """

    VERSION = 1
    # Dataset is loaded from a local datasets script shipped with the harness.
    DATASET_PATH = inspect.getfile(lm_eval.datasets.coqa.coqa)
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def test_docs(self):
        # No test split (has_test_docs is False).
        pass

    def doc_to_text(self, doc):
        # Given a passage p, the conversation history {q1, a1, . . . qi−1, ai−1}
        # and a question qi, the task is to predict the answer ai.
        # zip_longest pads the final question with a None answer so the
        # prompt ends with a bare "A:" for the model to complete.
        doc_text = doc["story"] + "\n\n"
        for (q, a) in zip_longest(
            doc["questions"]["input_text"], doc["answers"]["input_text"][:-1]
        ):  # omit target answer ai
            question = f"Q: {q}\n\n"
            answer = f"A: {a}\n\n" if a is not None else "A:"
            doc_text += question + answer
        return doc_text

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["story"] + " " + "\n".join(doc["questions"]["input_text"])

    @classmethod
    def get_answers(cls, doc, turn_id):
        # Returns unique answers and valid alternatives (Some questions in CoQA have multiple valid answers).
        answers = []
        answer_forturn = doc["answers"]["input_text"][turn_id - 1]
        answers.append(answer_forturn)

        additional_answers = doc.get("additional_answers")
        if additional_answers:
            for key in additional_answers:
                additional_answer_for_turn = additional_answers[key]["input_text"][
                    turn_id - 1
                ]
                # Deduplicate case-insensitively.
                if additional_answer_for_turn.lower() not in map(str.lower, answers):
                    answers.append(additional_answer_for_turn)
        return answers

    @classmethod
    def get_answer_choice(cls, raw_text):
        # FIX: first parameter was named `self` on a @classmethod; renamed to
        # `cls` (the body never used it, so behavior is unchanged).
        # Function maps answers to CoQA answer categories
        # ~ 1/5 of the CoQA answers are Yes/No
        # ~ 2/3 of the CoQA answers are span-based
        # (answers overlap with the passage ignoring punctuation and case mismatch)
        if raw_text == "unknown":
            return "0"
        if squad_metrics.normalize_answer(raw_text) == "yes":
            return "1"
        if squad_metrics.normalize_answer(raw_text) == "no":
            return "2"
        return "3"  # Not a yes/no question

    @staticmethod
    def compute_scores(gold_list, pred):
        # tests for exact match and on the normalised answer (compute_exact)
        # test for overlap (compute_f1)
        f1_sum = 0.0
        em_sum = 0.0
        if len(gold_list) > 1:
            for i in range(len(gold_list)):
                # Leave-one-out: score the prediction against the other golds
                # and take the maximum for each held-out gold.
                gold_answers = gold_list[0:i] + gold_list[i + 1 :]
                # predictions compared against (n) golds and take maximum
                em_sum += max(
                    squad_metrics.compute_exact(a, pred) for a in gold_answers
                )
                f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_answers)
        else:
            em_sum += max(squad_metrics.compute_exact(a, pred) for a in gold_list)
            f1_sum += max(squad_metrics.compute_f1(a, pred) for a in gold_list)

        # max(1, ...) guards against an empty gold list.
        return {
            "em": em_sum / max(1, len(gold_list)),
            "f1": f1_sum / max(1, len(gold_list)),
        }

    def doc_to_target(self, doc, turnid=None):
        # Default to prediction of last turn.
        if turnid is None:
            turnid = len(doc["questions"]["input_text"])
        raw_text = doc["answers"]["input_text"][turnid - 1]
        return " " + raw_text

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # Generate until the next question marker so the model answers only
        # the current turn.
        cont_request = rf.greedy_until(ctx, ["\nQ:"])
        return cont_request

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        turn_id = len(doc["questions"]["input_text"])
        gold_list = self.get_answers(doc, turn_id)
        # Keep only the first line of the generated continuation.
        pred = results[0].strip().split("\n")[0]

        scores = self.compute_scores(gold_list, pred)

        return {
            "f1": scores["f1"],
            "em": scores["em"],
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/hellaswag.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HellaSwag: Can a Machine Really Finish Your Sentence?
|
| 3 |
+
https://arxiv.org/pdf/1905.07830.pdf
|
| 4 |
+
|
| 5 |
+
Hellaswag is a commonsense inference challenge dataset. Though its questions are
|
| 6 |
+
trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). This is
|
| 7 |
+
achieved via Adversarial Filtering (AF), a data collection paradigm wherein a
|
| 8 |
+
series of discriminators iteratively select an adversarial set of machine-generated
|
| 9 |
+
wrong answers. AF proves to be surprisingly robust. The key insight is to scale up
|
| 10 |
+
the length and complexity of the dataset examples towards a critical 'Goldilocks'
|
| 11 |
+
zone wherein generated text is ridiculous to humans, yet often misclassified by
|
| 12 |
+
state-of-the-art models.
|
| 13 |
+
|
| 14 |
+
Homepage: https://rowanzellers.com/hellaswag/
|
| 15 |
+
"""
|
| 16 |
+
import re
|
| 17 |
+
from lm_eval.base import MultipleChoiceTask
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
_CITATION = """
|
| 21 |
+
@inproceedings{zellers2019hellaswag,
|
| 22 |
+
title={HellaSwag: Can a Machine Really Finish Your Sentence?},
|
| 23 |
+
author={Zellers, Rowan and Holtzman, Ari and Bisk, Yonatan and Farhadi, Ali and Choi, Yejin},
|
| 24 |
+
booktitle ={Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics},
|
| 25 |
+
year={2019}
|
| 26 |
+
}
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class HellaSwag(MultipleChoiceTask):
    """HellaSwag sentence-completion task (multiple choice over endings)."""

    VERSION = 0
    DATASET_PATH = "hellaswag"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # Cache the processed training docs on first access.
        if self._training_docs is None:
            self._training_docs = list(map(self._process_doc, self.dataset["train"]))
        return self._training_docs

    def validation_docs(self):
        return map(self._process_doc, self.dataset["validation"])

    def _process_doc(self, doc):
        # Build the MultipleChoiceTask format: query string, list of choice
        # strings, and the 0-based index of the correct ending.
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "query": self.preprocess(doc["activity_label"] + ": " + ctx),
            "choices": [self.preprocess(ending) for ending in doc["endings"]],
            "gold": int(doc["label"]),
        }
        return out_doc

    @classmethod
    def preprocess(cls, text):
        """Strip WikiHow markup artifacts from a HellaSwag text."""
        text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        text = text.replace(" [title]", ". ")
        text = re.sub("\\[.*?\\]", "", text)
        # NOTE(review): upstream collapses double spaces ("  " -> " ") here;
        # this literal looks whitespace-mangled in transit — confirm against
        # the original source.
        text = text.replace(" ", " ")
        return text

    def doc_to_text(self, doc):
        return doc["query"]

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["query"]
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/lambada.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The LAMBADA dataset: Word prediction requiring a broad discourse context∗
|
| 3 |
+
https://arxiv.org/pdf/1606.06031.pdf
|
| 4 |
+
|
| 5 |
+
LAMBADA is a dataset to evaluate the capabilities of computational models for text
|
| 6 |
+
understanding by means of a word prediction task. LAMBADA is a collection of narrative
|
| 7 |
+
passages sharing the characteristic that human subjects are able to guess their last
|
| 8 |
+
word if they are exposed to the whole passage, but not if they only see the last
|
| 9 |
+
sentence preceding the target word. To succeed on LAMBADA, computational models
|
| 10 |
+
cannot simply rely on local context, but must be able to keep track of information
|
| 11 |
+
in the broader discourse.
|
| 12 |
+
|
| 13 |
+
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
|
| 14 |
+
"""
|
| 15 |
+
from lm_eval.base import Task, rf
|
| 16 |
+
from lm_eval.metrics import mean, perplexity
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
_CITATION = """
|
| 20 |
+
@misc{
|
| 21 |
+
author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
|
| 22 |
+
title={The LAMBADA dataset},
|
| 23 |
+
DOI={10.5281/zenodo.2630551},
|
| 24 |
+
publisher={Zenodo},
|
| 25 |
+
year={2016},
|
| 26 |
+
month={Aug}
|
| 27 |
+
}
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class LambadaBase(Task):
    """Common LAMBADA logic: the model must predict the final word of each
    passage. Reports target perplexity ("ppl") and whether greedy decoding
    reproduced the word exactly ("acc")."""

    VERSION = None

    def training_docs(self):
        if not self.has_training_docs():
            return None
        return self.dataset["train"]

    def validation_docs(self):
        if not self.has_validation_docs():
            return None
        return self.dataset["validation"]

    def test_docs(self):
        if not self.has_test_docs():
            return None
        return self.dataset["test"]

    def doc_to_text(self, doc):
        # Context is everything before the final space-separated word.
        return doc["text"].rsplit(" ", 1)[0]

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["text"]

    def doc_to_target(self, doc):
        # Target is the final word, with a leading space for tokenization.
        final_word = doc["text"].rsplit(" ", 1)[1]
        return " " + final_word

    def construct_requests(self, doc, ctx):
        ll_req, greedy_req = rf.loglikelihood(ctx, self.doc_to_target(doc))
        return ll_req, greedy_req

    def process_results(self, doc, results):
        log_likelihood, matched_greedily = results
        return {"ppl": log_likelihood, "acc": int(matched_greedily)}

    def aggregation(self):
        return {"ppl": perplexity, "acc": mean}

    def higher_is_better(self):
        # Lower perplexity is better; higher accuracy is better.
        return {"ppl": False, "acc": True}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class LambadaStandard(LambadaBase):
    """The LAMBADA task using the standard original LAMBADA dataset."""

    VERSION = 0
    DATASET_PATH = "lambada"

    def has_training_docs(self):
        # This task exposes no training docs.
        return False

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class LambadaOpenAI(LambadaBase):
    """The LAMBADA task using the LAMBADA OpenAI dataset, a modified version of the
    original LAMBADA dataset created by OpenAI for evaluating their GPT-2 model.

    Reference: https://github.com/openai/gpt-2/issues/131#issuecomment-497136199
    """

    VERSION = 0
    DATASET_PATH = "EleutherAI/lambada_openai"

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        # Only the test split is used for this variant.
        return True
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/lambada_multilingual.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The LAMBADA (OpenAI) dataset: Word prediction requiring a broad discourse context∗
|
| 3 |
+
https://arxiv.org/pdf/1606.06031.pdf
|
| 4 |
+
|
| 5 |
+
The LAMBADA OpenAI dataset machine-translated to other languages.
|
| 6 |
+
LAMBADA is a dataset to evaluate the capabilities of computational models for text
|
| 7 |
+
understanding by means of a word prediction task. LAMBADA is a collection of narrative
|
| 8 |
+
passages sharing the characteristic that human subjects are able to guess their last
|
| 9 |
+
word if they are exposed to the whole passage, but not if they only see the last
|
| 10 |
+
sentence preceding the target word. To succeed on LAMBADA, computational models
|
| 11 |
+
cannot simply rely on local context, but must be able to keep track of information
|
| 12 |
+
in the broader discourse.
|
| 13 |
+
|
| 14 |
+
Homepage: https://zenodo.org/record/2630551#.X4Xzn5NKjUI
|
| 15 |
+
|
| 16 |
+
Reference (OpenAI): https://github.com/openai/gpt-2/issues/131#issuecomment-497136199
|
| 17 |
+
"""
|
| 18 |
+
import inspect
|
| 19 |
+
from .lambada import LambadaOpenAI
|
| 20 |
+
from lm_eval.base import rf
|
| 21 |
+
import lm_eval.datasets.lambada_ja.lambada_ja
|
| 22 |
+
from lm_eval.metrics import mean, perplexity
|
| 23 |
+
|
| 24 |
+
_CITATION = """
|
| 25 |
+
@misc{
|
| 26 |
+
author={Paperno, Denis and Kruszewski, Germán and Lazaridou, Angeliki and Pham, Quan Ngoc and Bernardi, Raffaella and Pezzelle, Sandro and Baroni, Marco and Boleda, Gemma and Fernández, Raquel},
|
| 27 |
+
title={The LAMBADA dataset},
|
| 28 |
+
DOI={10.5281/zenodo.2630551},
|
| 29 |
+
publisher={Zenodo},
|
| 30 |
+
year={2016},
|
| 31 |
+
month={Aug}
|
| 32 |
+
}
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class LambadaOpenAIMultilingualEnglish(LambadaOpenAI):
    """LAMBADA OpenAI, English split of the multilingual release."""

    VERSION = 0
    DATASET_NAME = "en"


class LambadaOpenAIMultilingualFrench(LambadaOpenAI):
    """LAMBADA OpenAI machine-translated to French."""

    VERSION = 0
    DATASET_NAME = "fr"


class LambadaOpenAIMultilingualGerman(LambadaOpenAI):
    """LAMBADA OpenAI machine-translated to German."""

    VERSION = 0
    DATASET_NAME = "de"


class LambadaOpenAIMultilingualItalian(LambadaOpenAI):
    """LAMBADA OpenAI machine-translated to Italian."""

    VERSION = 0
    DATASET_NAME = "it"


class LambadaOpenAIMultilingualSpanish(LambadaOpenAI):
    """LAMBADA OpenAI machine-translated to Spanish."""

    VERSION = 0
    DATASET_NAME = "es"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class LambadaOpenAIMultilingualJapanese(LambadaOpenAI):
    """LAMBADA machine-translated to Japanese, loaded from a local dataset
    script.

    Unlike the other languages, docs here are plain strings (see test_docs)
    and the target is a sentinel string (see doc_to_target).
    """

    VERSION = 0
    DATASET_PATH = inspect.getfile(lm_eval.datasets.lambada_ja.lambada_ja)
    DATASET_NAME = "ja"

    def test_docs(self):
        # TODO: because all lambda texts are not translated yet, only take 1k translated texts
        # return self.dataset['test']
        texts = [item["text"] for item in self.dataset["test"] if item["text"] != ""][
            :1000
        ]
        # remove last 。 (trailing Japanese full stop)
        texts = [text[:-1] if text[-1] == "。" else text for text in texts]
        return texts
        # for doc in self.dataset["test"]:
        #     yield doc["text"]

    def doc_to_text(self, doc):
        # Docs are bare strings here, so the whole doc is the context.
        # return doc.rsplit(" ", 1)[0]
        return doc
        # # using janome
        # try:
        #     from janome.tokenizer import Tokenizer
        #     t = Tokenizer()
        # except ImportError:
        #     raise ImportError("Please install janome first! (`pip install janome`)")
        # words = [token.surface for token in t.tokenize(doc)][:-1]
        # return "".join(words)

    def doc_to_target(self, doc):
        # return " " + doc["text"].rsplit(" ", 1)[1]
        # # take last token from context
        # NOTE(review): sentinel presumably interpreted downstream (model
        # layer) to mean "score the last token of the context" — confirm
        # where "__lasttoken__" is handled.
        return "__lasttoken__"
        # # using janome
        # try:
        #     from janome.tokenizer import Tokenizer
        #     t = Tokenizer()
        # except ImportError:
        #     raise ImportError("Please install janome first! (`pip install janome`)")
        # word = [token.surface for token in t.tokenize(doc)][-1]
        # return word

    def construct_requests(self, doc, ctx):
        ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))

        return ll, is_greedy
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# All language variants of the multilingual LAMBADA OpenAI task.
LANG_CLASSES = [
    LambadaOpenAIMultilingualEnglish,
    LambadaOpenAIMultilingualFrench,
    LambadaOpenAIMultilingualGerman,
    LambadaOpenAIMultilingualItalian,
    LambadaOpenAIMultilingualSpanish,
    LambadaOpenAIMultilingualJapanese,
]


def construct_tasks():
    """Map registry names ("lambada_openai_mt_<lang>") to their task classes."""
    return {
        f"lambada_openai_mt_{lang_cls.DATASET_NAME}": lang_cls
        for lang_cls in LANG_CLASSES
    }
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/qa4mre.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
QA4MRE 2011-2013: Overview of Question Answering for Machine Reading Evaluation
|
| 3 |
+
https://www.cs.cmu.edu/~./hovy/papers/13CLEF-QA4MRE.pdf
|
| 4 |
+
|
| 5 |
+
The (English only) QA4MRE challenge which was run as a Lab at CLEF 2011-2013.
|
| 6 |
+
The main objective of this exercise is to develop a methodology for evaluating
|
| 7 |
+
Machine Reading systems through Question Answering and Reading Comprehension
|
| 8 |
+
Tests. Systems should be able to extract knowledge from large volumes of text
|
| 9 |
+
and use this knowledge to answer questions. Four different tasks have been
|
| 10 |
+
organized during these years: Main Task, Processing Modality and Negation for
|
| 11 |
+
Machine Reading, Machine Reading of Biomedical Texts about Alzheimer's disease,
|
| 12 |
+
and Entrance Exam.
|
| 13 |
+
|
| 14 |
+
Homepage: http://nlp.uned.es/clef-qa/repository/qa4mre.php
|
| 15 |
+
"""
|
| 16 |
+
from lm_eval.base import MultipleChoiceTask
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
_CITATION = """
|
| 20 |
+
@inproceedings{Peas2013QA4MRE2O,
|
| 21 |
+
title={QA4MRE 2011-2013: Overview of Question Answering for Machine Reading Evaluation},
|
| 22 |
+
author={Anselmo Pe{\~n}as and Eduard H. Hovy and Pamela Forner and {\'A}lvaro Rodrigo and Richard F. E. Sutcliffe and Roser Morante},
|
| 23 |
+
booktitle={CLEF},
|
| 24 |
+
year={2013}
|
| 25 |
+
}
|
| 26 |
+
""" # noqa: W605
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class QA4MRE(MultipleChoiceTask):
    """QA4MRE machine-reading multiple-choice QA (CLEF 2011-2013)."""

    VERSION = 0
    DATASET_PATH = "qa4mre"
    DATASET_NAME = None  # set by the year-specific subclasses below

    def has_training_docs(self):
        return False

    def has_validation_docs(self):
        return False

    def has_test_docs(self):
        return True

    def test_docs(self):
        # `qa4mre` only has train data so we use it for the test docs.
        return map(self._process_doc, self.dataset["train"])

    def _process_doc(self, doc):
        # Convert a raw row into the MultipleChoiceTask format.
        choices = doc["answer_options"]["answer_str"]
        out_doc = {
            # NOTE(review): this replace maps an apostrophe to an identical
            # apostrophe (a no-op) — a curly quote may have been normalized
            # somewhere; confirm against the original source.
            "source": doc["document_str"].strip().replace("'", "'"),
            "query": doc["question_str"],
            # correct_answer_id is 1-based; gold must be a 0-based index.
            "choices": choices,
            "gold": int(doc["correct_answer_id"]) - 1,
        }
        return out_doc

    def doc_to_text(self, doc):
        return "{}\nQuestion: {}\nAnswer:".format(doc["source"], doc["query"])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["source"] + " " + doc["query"]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class QA4MRE_2011(QA4MRE):
    """QA4MRE main task, 2011 English edition."""

    DATASET_NAME = "2011.main.EN"


class QA4MRE_2012(QA4MRE):
    """QA4MRE main task, 2012 English edition."""

    DATASET_NAME = "2012.main.EN"


class QA4MRE_2013(QA4MRE):
    """QA4MRE main task, 2013 English edition."""

    DATASET_NAME = "2013.main.EN"
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/squad.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Know What You Don’t Know: Unanswerable Questions for SQuAD
|
| 3 |
+
https://arxiv.org/pdf/1806.03822.pdf
|
| 4 |
+
|
| 5 |
+
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset,
|
| 6 |
+
consisting of questions posed by crowdworkers on a set of Wikipedia articles,
|
| 7 |
+
where the answer to every question is a segment of text, or span, from the
|
| 8 |
+
corresponding reading passage, or the question might be unanswerable.
|
| 9 |
+
SQuAD2.0 combines the 100,000 questions in SQuAD1.1 with over 50,000 unanswerable
|
| 10 |
+
questions written adversarially by crowdworkers to look similar to answerable ones.
|
| 11 |
+
To do well on SQuAD2.0, systems must not only answer questions when possible, but
|
| 12 |
+
also determine when no answer is supported by the paragraph and abstain from answering.
|
| 13 |
+
|
| 14 |
+
Homepage: https://rajpurkar.github.io/SQuAD-explorer/
|
| 15 |
+
"""
|
| 16 |
+
import datasets
|
| 17 |
+
from math import exp
|
| 18 |
+
from lm_eval.base import rf, Task
|
| 19 |
+
from functools import partial
|
| 20 |
+
from packaging import version
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
_CITATION = """
|
| 24 |
+
@misc{rajpurkar2018know,
|
| 25 |
+
title={Know What You Don't Know: Unanswerable Questions for SQuAD},
|
| 26 |
+
author={Pranav Rajpurkar and Robin Jia and Percy Liang},
|
| 27 |
+
year={2018},
|
| 28 |
+
eprint={1806.03822},
|
| 29 |
+
archivePrefix={arXiv},
|
| 30 |
+
primaryClass={cs.CL}
|
| 31 |
+
}
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _squad_metric(predictions, references):
    """Compute the full squad_v2 metric dict for the given predictions/references.

    NOTE(review): `datasets.load_metric` is deprecated upstream (metrics moved
    to the `evaluate` package) — confirm the pinned `datasets` version still
    provides it.
    """
    squad_metric = datasets.load_metric("squad_v2")
    return squad_metric.compute(predictions=predictions, references=references)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _squad_agg(key, items):
    """Aggregate accumulated (prediction, reference) pairs into one metric.

    `key` selects which field of the squad_v2 result dict to report;
    a missing key falls back to 0.
    """
    preds, refs = zip(*items)
    return _squad_metric(predictions=preds, references=refs).get(key, 0)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class SQuAD2(Task):
    """SQuAD v2.0 reading comprehension task.

    Free-form answer extraction over Wikipedia passages, where some
    questions are unanswerable; the model is additionally scored on the
    loglikelihood of answering "unanswerable".
    """

    VERSION = 1
    DATASET_PATH = "squad_v2"
    DATASET_NAME = None

    # HF changed squad on us so we have to make sure we aren't running the old one
    assert version.parse(datasets.__version__) >= version.parse(
        "1.11.0"
    ), "datasets v1.11.0 or later required for SQuAD"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        return self.dataset["train"]

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        # Prompt layout: title / background passage / question / answer cue.
        return (
            "Title: "
            + doc["title"]
            + "\n\n"
            + "Background: "
            + doc["context"]
            + "\n\n"
            + "Question: "
            + doc["question"]
            + "\n\n"
            + "Answer:"
        )

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        # The context passage is what might leak from pretraining data.
        return doc["context"]

    def doc_to_target(self, doc):
        # Unanswerable questions have an empty answers list; the few-shot
        # target for those is the literal string "unanswerable".
        answer_list = doc["answers"]["text"]
        if len(answer_list) > 0:
            answer = answer_list[0]
        else:
            answer = "unanswerable"
        return " " + answer

    def construct_requests(self, doc, ctx):
        """Uses RequestFactory to construct Requests and returns an iterable of
        Requests which will be sent to the LM.

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param ctx: str
            The context string, generated by fewshot_context. This includes the natural
            language description, as well as the few shot examples, and the question
            part of the document for `doc`.
        """
        # Two requests per doc: a free-form answer (stop at newline), and the
        # loglikelihood of the literal "unanswerable" continuation, which is
        # converted to a no-answer probability in process_results.
        continuation = rf.greedy_until(ctx, ["\n"])
        is_unanswerable = rf.loglikelihood(ctx, " " + "unanswerable")
        return continuation, is_unanswerable

    def process_results(self, doc, results):
        """Take a single document and the LM results and evaluates, returning a
        dict where keys are the names of submetrics and values are the values of
        the metric for that one document

        :param doc:
            The document as returned from training_docs, validation_docs, or test_docs.
        :param results:
            The results of the requests created in construct_requests.
        """
        continuation, (logprob_unanswerable, _) = results

        no_answer_probability = exp(logprob_unanswerable)

        # Shapes expected by the HF squad_v2 metric.
        predictions = {
            "id": doc["id"],
            "prediction_text": continuation,
            "no_answer_probability": no_answer_probability,
        }

        references = {
            "id": doc["id"],
            "answers": doc["answers"],
        }

        # Every sub-metric receives the same (predictions, references) pair;
        # `_squad_agg` computes the full metric once per key and selects it.
        return {
            "exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": (
                predictions,
                references,
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": (
                predictions,
                references,
            ),  # The F-score of predicted tokens versus the gold answer
            "best_exact": (
                predictions,
                references,
            ),  # Best exact match (with varying threshold)
            "best_f1": (predictions, references),  # Best F1 (with varying threshold)
        }

    def aggregation(self):
        """
        :returns: {str: [float] -> float}
            A dictionary where keys are the names of submetrics and values are
            functions that aggregate a list of metrics
        """
        return {
            "exact": partial(
                _squad_agg, "exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "f1": partial(
                _squad_agg, "f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": partial(
                _squad_agg, "HasAns_exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": partial(
                _squad_agg, "HasAns_f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": partial(
                _squad_agg, "NoAns_exact"
            ),  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": partial(
                _squad_agg, "NoAns_f1"
            ),  # The F-score of predicted tokens versus the gold answer
            "best_exact": partial(
                _squad_agg, "best_exact"
            ),  # Best exact match (with varying threshold)
            "best_f1": partial(
                _squad_agg, "best_f1"
            ),  # Best F1 (with varying threshold)
        }

    def higher_is_better(self):
        """
        :returns: {str: bool}
            A dictionary where keys are the names of submetrics and values are
            whether a higher value of the submetric is better
        """
        return {
            "exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "f1": True,  # The F-score of predicted tokens versus the gold answer
            "HasAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "HasAns_f1": True,  # The F-score of predicted tokens versus the gold answer
            "NoAns_exact": True,  # Exact match (the normalized answer exactly match the gold answer)
            "NoAns_f1": True,  # The F-score of predicted tokens versus the gold answer
            "best_exact": True,  # Best exact match (with varying threshold)
            "best_f1": True,  # Best F1 (with varying threshold)
        }
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/superglue.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems
|
| 3 |
+
https://w4ngatang.github.io/static/papers/superglue.pdf
|
| 4 |
+
|
| 5 |
+
SuperGLUE is a benchmark styled after GLUE with a new set of more difficult language
|
| 6 |
+
understanding tasks.
|
| 7 |
+
|
| 8 |
+
Homepage: https://super.gluebenchmark.com/
|
| 9 |
+
|
| 10 |
+
TODO: WSC requires free-form generation.
|
| 11 |
+
"""
|
| 12 |
+
import numpy as np
|
| 13 |
+
import sklearn
|
| 14 |
+
import transformers.data.metrics.squad_metrics as squad_metrics
|
| 15 |
+
from lm_eval.base import rf, Task
|
| 16 |
+
from lm_eval.metrics import mean, acc_all, metric_max_over_ground_truths, yesno
|
| 17 |
+
from lm_eval.utils import general_detokenize
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
_CITATION = """
|
| 21 |
+
@inproceedings{NEURIPS2019_4496bf24,
|
| 22 |
+
author = {Wang, Alex and Pruksachatkun, Yada and Nangia, Nikita and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel},
|
| 23 |
+
booktitle = {Advances in Neural Information Processing Systems},
|
| 24 |
+
editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
|
| 25 |
+
pages = {},
|
| 26 |
+
publisher = {Curran Associates, Inc.},
|
| 27 |
+
title = {SuperGLUE: A Stickier Benchmark for General-Purpose Language Understanding Systems},
|
| 28 |
+
url = {https://proceedings.neurips.cc/paper/2019/file/4496bf24afe7fab6f046bf4923da8de6-Paper.pdf},
|
| 29 |
+
volume = {32},
|
| 30 |
+
year = {2019}
|
| 31 |
+
}
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class BoolQ(Task):
    """SuperGLUE BoolQ: answer a yes/no question about a short passage."""

    VERSION = 1
    DATASET_PATH = "super_glue"
    DATASET_NAME = "boolq"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # Materialize the train split once; reused for few-shot sampling.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return "{}\nQuestion: {}?\nAnswer:".format(doc["passage"], doc["question"])

    def should_decontaminate(self):
        return True

    def doc_to_decontamination_query(self, doc):
        return doc["passage"]

    def doc_to_target(self, doc):
        return " " + yesno(doc["label"])

    def construct_requests(self, doc, ctx):
        # Compare the likelihoods of " yes" vs " no" as continuations.
        return (
            rf.loglikelihood(ctx, " yes")[0],
            rf.loglikelihood(ctx, " no")[0],
        )

    def process_results(self, doc, results):
        ll_yes, ll_no = results
        predicted_yes = ll_yes > ll_no
        # label is 0/1; bool == int comparison yields the accuracy bit.
        return {"acc": float(predicted_yes == doc["label"])}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class CommitmentBank(Task):
    """SuperGLUE CommitmentBank (CB): three-way entailment classification."""

    VERSION = 1
    DATASET_PATH = "super_glue"
    DATASET_NAME = "cb"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # Cache the train split so few-shot sampling is cheap on repeat calls.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        template = "{}\nQuestion: {}. True, False or Neither?\nAnswer:"
        return template.format(doc["premise"], doc["hypothesis"])

    def doc_to_target(self, doc):
        # True = entailment, False = contradiction, Neither = neutral
        verbalized = {0: "True", 1: "False", 2: "Neither"}[doc["label"]]
        return " " + verbalized

    def construct_requests(self, doc, ctx):
        # One loglikelihood request per verbalized label, in label order.
        ll_true, _ = rf.loglikelihood(ctx, " True")
        ll_false, _ = rf.loglikelihood(ctx, " False")
        ll_neither, _ = rf.loglikelihood(ctx, " Neither")
        return ll_true, ll_false, ll_neither

    def process_results(self, doc, results):
        gold = doc["label"]
        pred = np.argmax(results)
        # f1 is aggregated corpus-level from raw (pred, gold) pairs.
        return {"acc": float(pred == gold), "f1": (pred, gold)}

    def higher_is_better(self):
        return {"acc": True, "f1": True}

    @classmethod
    def cb_multi_fi(cls, items):
        """Macro-average of one-vs-rest F1 over the three CB classes."""
        preds, golds = zip(*items)
        preds = np.array(preds)
        golds = np.array(golds)
        per_class = [
            sklearn.metrics.f1_score(y_true=golds == label, y_pred=preds == label)
            for label in (0, 1, 2)
        ]
        return mean(per_class)

    def aggregation(self):
        return {
            "acc": mean,
            "f1": self.cb_multi_fi,
        }
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class Copa(Task):
    """SuperGLUE COPA: choose the more plausible cause/effect of a premise."""

    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "copa"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # Train split is cached after the first materialization.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        # Drop the period
        connectors = {"cause": "because", "effect": "therefore"}
        connector = connectors[doc["question"]]
        return doc["premise"].strip()[:-1] + " " + connector

    def doc_to_target(self, doc):
        # label 0 -> choice1, label 1 -> choice2.
        correct_choice = doc["choice2"] if doc["label"] else doc["choice1"]
        # Connect the sentences
        return " " + self.convert_choice(correct_choice)

    def construct_requests(self, doc, ctx):
        lls = []
        for key in ("choice1", "choice2"):
            ll, _ = rf.loglikelihood(ctx, " " + self.convert_choice(doc[key]))
            lls.append(ll)
        return tuple(lls)

    def process_results(self, doc, results):
        pred = np.argmax(results)
        return {"acc": float(pred == doc["label"])}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}

    @staticmethod
    def convert_choice(choice):
        """Lowercase the first character so the choice continues a sentence."""
        return choice[0].lower() + choice[1:]
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class MultiRC(Task):
    """SuperGLUE MultiRC: judge candidate answers to paragraph questions."""

    VERSION = 1
    DATASET_PATH = "super_glue"
    DATASET_NAME = "multirc"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # Train split is cached after the first materialization.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        return "{}\nQuestion: {}\nAnswer:".format(doc["paragraph"], doc["question"])

    def doc_to_target(self, doc):
        return " " + self.format_answer(answer=doc["answer"], label=doc["label"])

    @staticmethod
    def format_answer(answer, label):
        """Render a candidate answer followed by a yes/no correctness verdict."""
        verdict = "yes" if label else "no"
        return f"{answer}\nIs the answer correct? {verdict}"

    def construct_requests(self, doc, ctx):
        # Score the same candidate answer rendered as correct vs. incorrect.
        ll_true_choice, _ = rf.loglikelihood(
            ctx, " " + self.format_answer(answer=doc["answer"], label=True)
        )
        ll_false_choice, _ = rf.loglikelihood(
            ctx, " " + self.format_answer(answer=doc["answer"], label=False)
        )
        return ll_true_choice, ll_false_choice

    def process_results(self, doc, results):
        ll_true_choice, ll_false_choice = results
        # acc_all needs the doc itself to group answers per question.
        return {"acc": (ll_true_choice > ll_false_choice, doc)}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": acc_all}
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
class ReCoRD(Task):
    """SuperGLUE ReCoRD: cloze-style reading comprehension.

    Each doc has a news passage with "@highlight" bullets, a query containing
    an "@placeholder" slot, a list of candidate entities, and the set of
    correct answers; every candidate is scored by loglikelihood.
    """

    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "record"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # In ReCoRD, each doc manifests multiple "examples" in the context of few shot example packing.
        # Each doc consists of multiple answer candidates, each of which is scored yes/no.
        if self._training_docs is None:
            self._training_docs = []
            for doc in self.dataset["train"]:
                self._training_docs.append(self._process_doc(doc))
        return self._training_docs

    def validation_docs(self):
        # See: training_docs
        for doc in self.dataset["validation"]:
            yield self._process_doc(doc)

    @classmethod
    def _process_doc(cls, doc):
        # Deduplicate then sort so candidate/answer ordering is deterministic;
        # construct_requests and process_results both rely on that ordering.
        return {
            "passage": doc["passage"],
            "query": doc["query"],
            "entities": sorted(list(set(doc["entities"]))),
            "answers": sorted(list(set(doc["answers"]))),
        }

    def doc_to_text(self, doc):
        # Render the passage body, then each "@highlight" as a bullet line.
        initial_text, *highlights = doc["passage"].strip().split("\n@highlight\n")
        text = initial_text + "\n\n"
        for highlight in highlights:
            text += f" - {highlight}.\n"
        return text

    @classmethod
    def format_answer(cls, query, entity):
        # Fill the "@placeholder" slot in the query with a candidate entity.
        return f" - {query}".replace("@placeholder", entity)

    def doc_to_target(self, doc):
        # We only output the first correct entity in a doc
        return self.format_answer(query=doc["query"], entity=doc["answers"][0])

    def construct_requests(self, doc, ctx):
        """One loglikelihood request per candidate entity, in entity order."""
        requests = [
            rf.loglikelihood(ctx, self.format_answer(query=doc["query"], entity=entity))
            for entity in doc["entities"]
        ]
        return requests

    def process_results(self, doc, results):
        # ReCoRD's evaluation is actually deceptively simple:
        # - Pick the maximum likelihood prediction entity
        # - Evaluate the accuracy and token F1 PER EXAMPLE
        # - Average over all examples
        # results arrive in the same order as doc["entities"], so argmax over
        # loglikelihoods indexes directly into the entity list.
        max_idx = np.argmax(np.array([result[0] for result in results]))

        prediction = doc["entities"][max_idx]
        gold_label_set = doc["answers"]
        f1 = metric_max_over_ground_truths(
            squad_metrics.compute_f1, prediction, gold_label_set
        )
        em = metric_max_over_ground_truths(
            squad_metrics.compute_exact, prediction, gold_label_set
        )

        return {
            "f1": f1,
            "em": em,
        }

    def higher_is_better(self):
        return {
            "f1": True,
            "em": True,
        }

    def aggregation(self):
        return {
            "f1": mean,
            "em": mean,
        }
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
class WordsInContext(Task):
    """SuperGLUE WiC: is a word used with the same sense in two sentences?"""

    VERSION = 0
    DATASET_PATH = "super_glue"
    DATASET_NAME = "wic"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        # Train split is cached after the first materialization.
        if self._training_docs is None:
            self._training_docs = list(self.dataset["train"])
        return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        # Target word is sliced out of sentence1 via character offsets.
        target_word = doc["sentence1"][doc["start1"] : doc["end1"]]
        return (
            "Sentence 1: {}\nSentence 2: {}\nQuestion: Is the word '{}' used in the same way in the"
            " two sentences above?\nAnswer:".format(
                doc["sentence1"],
                doc["sentence2"],
                target_word,
            )
        )

    def doc_to_target(self, doc):
        # label 1 -> same sense ("yes"), label 0 -> different sense ("no").
        return " " + ("yes" if doc["label"] else "no")

    def construct_requests(self, doc, ctx):
        return (
            rf.loglikelihood(ctx, " yes")[0],
            rf.loglikelihood(ctx, " no")[0],
        )

    def process_results(self, doc, results):
        ll_yes, ll_no = results
        predicted_same = ll_yes > ll_no
        return {"acc": float(predicted_same == doc["label"])}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
class SGWinogradSchemaChallenge(Task):
    """SuperGLUE WSC: does a marked pronoun refer to a marked noun phrase?"""

    VERSION = 0
    # Note: This implementation differs from Fig G.32 because this is the SuperGLUE,
    # binary version of the task.
    DATASET_PATH = "super_glue"
    DATASET_NAME = "wsc"

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return False

    def training_docs(self):
        if self.has_training_docs():
            if self._training_docs is None:
                # GPT-3 Paper's format only uses positive examples for fewshot "training"
                self._training_docs = [
                    doc for doc in self.dataset["train"] if doc["label"]
                ]
            return self._training_docs

    def validation_docs(self):
        return self.dataset["validation"]

    def doc_to_text(self, doc):
        """Render the passage with the pronoun wrapped in *asterisks*."""
        raw_passage = doc["text"]
        # NOTE: HuggingFace span indices are word-based not character-based.
        pre = " ".join(raw_passage.split()[: doc["span2_index"]])
        # +1 skips the space between `pre` and the pronoun. NOTE(review):
        # assumes the split/join above reproduces the raw text's spacing up to
        # the pronoun — verify on passages with irregular whitespace.
        post = raw_passage[len(pre) + len(doc["span2_text"]) + 1 :]
        passage = general_detokenize(pre + " *{}*".format(doc["span2_text"]) + post)
        noun = doc["span1_text"]
        pronoun = doc["span2_text"]
        text = (
            f"Passage: {passage}\n"
            + f'Question: In the passage above, does the pronoun "*{pronoun}*" refer to "*{noun}*"?\n'
            + "Answer:"
        )
        return text

    def doc_to_target(self, doc):
        return " " + yesno(doc["label"])

    def construct_requests(self, doc, ctx):
        # Compare loglikelihoods of the " yes" / " no" continuations.
        ll_yes, _ = rf.loglikelihood(ctx, " yes")
        ll_no, _ = rf.loglikelihood(ctx, " no")

        return ll_yes, ll_no

    def process_results(self, doc, results):
        ll_yes, ll_no = results
        gold = doc["label"]

        # label is 0/1; bool == int comparison yields the accuracy bit.
        acc = 1.0 if (ll_yes > ll_no) == gold else 0.0

        return {"acc": acc}

    def higher_is_better(self):
        return {"acc": True}

    def aggregation(self):
        return {"acc": mean}
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/tasks/translation.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
NOTE: This file implements translation tasks using datasets from WMT conferences,
|
| 3 |
+
provided by sacrebleu. Traditionally they are evaluated with BLEU scores. TER
|
| 4 |
+
and CHRF are other options.
|
| 5 |
+
|
| 6 |
+
We defer citations and descriptions of the many translations tasks used
|
| 7 |
+
here to the SacreBLEU repo from which we've obtained the datasets:
|
| 8 |
+
https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/dataset.py
|
| 9 |
+
|
| 10 |
+
Homepage: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/dataset.py
|
| 11 |
+
"""
|
| 12 |
+
import pycountry
|
| 13 |
+
from pprint import pprint
|
| 14 |
+
from sacrebleu import sacrebleu
|
| 15 |
+
from lm_eval import metrics
|
| 16 |
+
from lm_eval.base import Task, rf
|
| 17 |
+
from typing import List
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
import nagisa
|
| 21 |
+
|
| 22 |
+
HAS_NAGISA = True
|
| 23 |
+
except ImportError:
|
| 24 |
+
HAS_NAGISA = False
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
import jieba
|
| 28 |
+
|
| 29 |
+
HAS_JIEBA = True
|
| 30 |
+
except ImportError:
|
| 31 |
+
HAS_JIEBA = False
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
_CITATION = """
|
| 35 |
+
@inproceedings{post-2018-call,
|
| 36 |
+
title = "A Call for Clarity in Reporting {BLEU} Scores",
|
| 37 |
+
author = "Post, Matt",
|
| 38 |
+
booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
|
| 39 |
+
month = oct,
|
| 40 |
+
year = "2018",
|
| 41 |
+
address = "Belgium, Brussels",
|
| 42 |
+
publisher = "Association for Computational Linguistics",
|
| 43 |
+
url = "https://www.aclweb.org/anthology/W18-6319",
|
| 44 |
+
pages = "186--191",
|
| 45 |
+
}
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Re-export sacrebleu's dataset registry so callers can enumerate the
# available WMT test sets without importing sacrebleu themselves.
sacrebleu_datasets = sacrebleu.DATASETS
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def create_tasks_from_benchmarks(benchmark_dict):
    """Creates a dictionary of tasks from a dict
    :param benchmark_dict: { dataset: [lang_pair, ...], }
    :return: {task_name: task}
        e.g. {wmt14-fr-en: Task, wmt16-de-en: Task}
    """

    def version_of(dataset, language_pair):
        # Pairs targeting zh/ja are segmented before scoring (jieba/nagisa),
        # which changed results -> version bump to 1.
        return 1 if language_pair[-2:] in ["zh", "ja"] else 0

    tasks = {}
    for dataset, language_pairs in benchmark_dict.items():
        for language_pair in language_pairs:
            task_name = f"{dataset}-{language_pair}"
            tasks[task_name] = create_translation_task(
                dataset, language_pair, version_of(dataset, language_pair)
            )
    return tasks
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
########################################
|
| 74 |
+
# Language Specifics
|
| 75 |
+
########################################
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def zh_split(zh_text: List[str]) -> List[str]:
    """Chinese splitting: segment each line into space-joined jieba tokens."""
    if not HAS_JIEBA:
        raise ImportError(
            "Chinese text splitting requires the `jieba` package. "
            "Please install it with:\npip install jieba"
        )

    segmented = []
    for txt in zh_text:
        segmented.append(" ".join(jieba.cut(txt.strip())))
    return segmented
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def ja_split(ja_text: List[str]) -> List[str]:
|
| 90 |
+
"""Japanese splitting"""
|
| 91 |
+
if not HAS_NAGISA:
|
| 92 |
+
raise ImportError(
|
| 93 |
+
"Japanese text splitting requires the `nagisa` package. "
|
| 94 |
+
"Please install it with:\npip install nagisa"
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
return [" ".join(nagisa.tagging(txt.strip()).words) for txt in ja_text]
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
NO_SPACE_LANG = {"zh": zh_split, "ja": ja_split}
|
| 101 |
+
|
| 102 |
+
########################################
|
| 103 |
+
# Tasks
|
| 104 |
+
########################################
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def create_translation_task(dataset, language_pair, version=0):
|
| 108 |
+
class TranslationTask(GeneralTranslationTask):
|
| 109 |
+
VERSION = version
|
| 110 |
+
|
| 111 |
+
def __init__(self):
|
| 112 |
+
super().__init__(dataset, language_pair)
|
| 113 |
+
|
| 114 |
+
return TranslationTask
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class GeneralTranslationTask(Task):
|
| 118 |
+
VERSION = 0
|
| 119 |
+
|
| 120 |
+
# e.g. ("wmt14", "fr-en")
|
| 121 |
+
def __init__(self, sacrebleu_dataset, sacrebleu_language_pair=None):
|
| 122 |
+
self.sacrebleu_dataset = sacrebleu_dataset
|
| 123 |
+
self.sacrebleu_language_pair = sacrebleu_language_pair
|
| 124 |
+
self.src_file = self.ref_file = self.src_data = self.ref_data = None
|
| 125 |
+
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
def download(self, data_dir=None, cache_dir=None, download_mode=None):
|
| 129 |
+
# This caches in the users home dir automatically
|
| 130 |
+
self.src_file, self.ref_file = sacrebleu.download_test_set(
|
| 131 |
+
self.sacrebleu_dataset, self.sacrebleu_language_pair
|
| 132 |
+
)
|
| 133 |
+
self.src_data, self.ref_data = [
|
| 134 |
+
[line.rstrip() for line in sacrebleu.smart_open(file)]
|
| 135 |
+
for file in (self.src_file, self.ref_file)
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
def has_training_docs(self):
|
| 139 |
+
"""Whether the task has a training set"""
|
| 140 |
+
# TODO In the future we could be more discerning. Some more recent tests have train and dev sets
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
def has_validation_docs(self):
|
| 144 |
+
"""Whether the task has a validation set"""
|
| 145 |
+
return False
|
| 146 |
+
|
| 147 |
+
def has_test_docs(self):
|
| 148 |
+
"""Whether the task has a test set"""
|
| 149 |
+
return True
|
| 150 |
+
|
| 151 |
+
def test_docs(self):
|
| 152 |
+
"""
|
| 153 |
+
:return: Iterable[obj]
|
| 154 |
+
A iterable of any object, that doc_to_text can handle
|
| 155 |
+
"""
|
| 156 |
+
return [
|
| 157 |
+
{"src": src, "ref": ref} for src, ref in zip(self.src_data, self.ref_data)
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
def doc_to_text(self, doc):
|
| 161 |
+
language_codes = self.sacrebleu_language_pair.split("-")
|
| 162 |
+
src_lang = code_to_language(language_codes[0])
|
| 163 |
+
tar_lang = code_to_language(language_codes[1])
|
| 164 |
+
return f"{src_lang} phrase: " + doc["src"] + f"\n{tar_lang} phrase:"
|
| 165 |
+
|
| 166 |
+
def should_decontaminate(self):
|
| 167 |
+
return True
|
| 168 |
+
|
| 169 |
+
def doc_to_decontamination_query(self, doc):
|
| 170 |
+
return doc["src"]
|
| 171 |
+
|
| 172 |
+
def doc_to_target(self, doc):
|
| 173 |
+
# This shows a single target, though there may be multiple targets in a lang test
|
| 174 |
+
return " " + doc["ref"] if isinstance(doc["ref"], str) else doc["ref"][0]
|
| 175 |
+
|
| 176 |
+
def construct_requests(self, doc, ctx):
|
| 177 |
+
"""Uses RequestFactory to construct Requests and returns an iterable of
|
| 178 |
+
Requests which will be sent to the LM.
|
| 179 |
+
|
| 180 |
+
:param doc:
|
| 181 |
+
The document as returned from training_docs, validation_docs, or test_docs.
|
| 182 |
+
:param ctx: str
|
| 183 |
+
The context string, generated by fewshot_context. This includes the natural
|
| 184 |
+
language description, as well as the few shot examples, and the question
|
| 185 |
+
part of the document for `doc`.
|
| 186 |
+
"""
|
| 187 |
+
return rf.greedy_until(ctx, ["\n"])
|
| 188 |
+
|
| 189 |
+
def process_results(self, doc, results):
|
| 190 |
+
# Add spaces between words for BLEU score calculation of target languages like Chinese
|
| 191 |
+
tar_lang_code = self.sacrebleu_language_pair.split("-")[-1]
|
| 192 |
+
if tar_lang_code in NO_SPACE_LANG:
|
| 193 |
+
doc["ref"] = NO_SPACE_LANG[tar_lang_code]([doc["ref"]])[0]
|
| 194 |
+
results = NO_SPACE_LANG[tar_lang_code](results)
|
| 195 |
+
|
| 196 |
+
# These metrics are corpus-level not sentence level, so we'll hide the
|
| 197 |
+
# results in this dict and compute the corpus score in the aggregate method
|
| 198 |
+
ref_pred = (doc["ref"], results)
|
| 199 |
+
return {
|
| 200 |
+
"bleu": ref_pred,
|
| 201 |
+
"chrf": ref_pred,
|
| 202 |
+
"ter": ref_pred,
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
def aggregation(self):
|
| 206 |
+
"""
|
| 207 |
+
:returns: {str: [float] -> float}
|
| 208 |
+
A dictionary where keys are the names of submetrics and values are
|
| 209 |
+
functions that aggregate a list of metrics
|
| 210 |
+
"""
|
| 211 |
+
return {
|
| 212 |
+
"bleu": metrics.bleu,
|
| 213 |
+
"chrf": metrics.chrf,
|
| 214 |
+
"ter": metrics.ter,
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
def higher_is_better(self):
|
| 218 |
+
"""
|
| 219 |
+
:returns: {str: bool}
|
| 220 |
+
A dictionary where keys are the names of submetrics and values are
|
| 221 |
+
whether a higher value of the submetric is better
|
| 222 |
+
"""
|
| 223 |
+
return {
|
| 224 |
+
"bleu": True,
|
| 225 |
+
"chrf": True,
|
| 226 |
+
"ter": False,
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
def __str__(self):
|
| 230 |
+
language_codes = self.sacrebleu_language_pair.split("-")
|
| 231 |
+
src_lang = code_to_language(language_codes[0])
|
| 232 |
+
tar_lang = code_to_language(language_codes[1])
|
| 233 |
+
return f"{self.sacrebleu_dataset.upper()} {src_lang} to {tar_lang} Task"
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
########################################
|
| 237 |
+
# Util
|
| 238 |
+
########################################
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def code_to_language(code):
|
| 242 |
+
# key is alpha_2 or alpha_3 depending on the code length
|
| 243 |
+
language_tuple = pycountry.languages.get(**{f"alpha_{len(code)}": code})
|
| 244 |
+
return language_tuple.name
|
scripts/yans/eval/lm-evaluation-harness/lm_eval/utils.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
import re
|
| 4 |
+
import collections
|
| 5 |
+
import functools
|
| 6 |
+
import inspect
|
| 7 |
+
import sys
|
| 8 |
+
from typing import List, Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
|
| 14 |
+
import sacrebleu
|
| 15 |
+
from rouge_score import rouge_scorer, scoring
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ExitCodeError(Exception):
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def sh(x):
|
| 23 |
+
if os.system(x):
|
| 24 |
+
raise ExitCodeError()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def simple_parse_args_string(args_string):
|
| 28 |
+
"""
|
| 29 |
+
Parses something like
|
| 30 |
+
args1=val1,arg2=val2
|
| 31 |
+
Into a dictionary
|
| 32 |
+
"""
|
| 33 |
+
args_string = args_string.strip()
|
| 34 |
+
if not args_string:
|
| 35 |
+
return {}
|
| 36 |
+
arg_list = args_string.split(",")
|
| 37 |
+
args_dict = OmegaConf.to_object(OmegaConf.from_dotlist(arg_list))
|
| 38 |
+
return args_dict
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def join_iters(iters):
|
| 42 |
+
for iter in iters:
|
| 43 |
+
yield from iter
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def chunks(iter, n):
|
| 47 |
+
arr = []
|
| 48 |
+
for x in iter:
|
| 49 |
+
arr.append(x)
|
| 50 |
+
if len(arr) == n:
|
| 51 |
+
yield arr
|
| 52 |
+
arr = []
|
| 53 |
+
|
| 54 |
+
if arr:
|
| 55 |
+
yield arr
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def group(arr, fn):
|
| 59 |
+
res = collections.defaultdict(list)
|
| 60 |
+
|
| 61 |
+
for ob in arr:
|
| 62 |
+
res[fn(ob)].append(ob)
|
| 63 |
+
|
| 64 |
+
return list(res.values())
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def general_detokenize(string):
|
| 68 |
+
string = string.replace(" n't", "n't")
|
| 69 |
+
string = string.replace(" )", ")")
|
| 70 |
+
string = string.replace("( ", "(")
|
| 71 |
+
string = string.replace('" ', '"')
|
| 72 |
+
string = string.replace(' "', '"')
|
| 73 |
+
string = re.sub(r" (['.,])", r"\1", string)
|
| 74 |
+
return string
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
|
| 78 |
+
"""
|
| 79 |
+
- context_len allows for a rolling window context, allowing each prediction window to potentially
|
| 80 |
+
condition on some context
|
| 81 |
+
|
| 82 |
+
:param token_list: list
|
| 83 |
+
List of tokens to be PREDICTED
|
| 84 |
+
:param max_seq_len: int
|
| 85 |
+
max_seq_len of model (or max_seq_len we want to use)
|
| 86 |
+
:param context_len: int
|
| 87 |
+
Amount of desired token context for prediction. Needs to be at least 1.
|
| 88 |
+
:param prefix_token: token
|
| 89 |
+
Dummy token like <eos> so the first token has something to condition on
|
| 90 |
+
:return: generator
|
| 91 |
+
Generator of tuples
|
| 92 |
+
(input_tokens, pred_tokens)
|
| 93 |
+
Note: Score only the last len(pred_tokens) logits of the LM
|
| 94 |
+
"""
|
| 95 |
+
assert 1 <= context_len <= max_seq_len
|
| 96 |
+
if not token_list:
|
| 97 |
+
return
|
| 98 |
+
# +1 offset, going from input->preds
|
| 99 |
+
pred_len = max_seq_len - context_len + 1
|
| 100 |
+
predicted = 0
|
| 101 |
+
|
| 102 |
+
# Special handling for first window: predict all tokens
|
| 103 |
+
first_seq_len = min(max_seq_len, len(token_list))
|
| 104 |
+
yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
|
| 105 |
+
predicted += first_seq_len
|
| 106 |
+
|
| 107 |
+
while predicted < len(token_list):
|
| 108 |
+
window_pred_len = min(len(token_list) - predicted, pred_len)
|
| 109 |
+
window_end = predicted + window_pred_len
|
| 110 |
+
|
| 111 |
+
yield (
|
| 112 |
+
token_list[window_end - max_seq_len - 1 : window_end - 1],
|
| 113 |
+
token_list[window_end - window_pred_len : window_end],
|
| 114 |
+
)
|
| 115 |
+
predicted += window_pred_len
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def make_disjoint_window(pair):
|
| 119 |
+
"""Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
|
| 120 |
+
a, b = pair
|
| 121 |
+
return a[: len(a) - (len(b) - 1)], b
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def select_continuation_from_batch_left_padding(
|
| 125 |
+
generations: Union[List[List[int]], torch.Tensor], max_context_size: int
|
| 126 |
+
):
|
| 127 |
+
"""Select the continuation from the batch, removing prompts of different lengths.
|
| 128 |
+
Args:
|
| 129 |
+
generations (Union[List[List[int]], torch.Tensor]):
|
| 130 |
+
A tensor or list-of-lists of shape [batch_size, sequence length].
|
| 131 |
+
max_context_size (int):
|
| 132 |
+
The size of the biggest context; generations will proceed from that
|
| 133 |
+
index.
|
| 134 |
+
Example:
|
| 135 |
+
PAD PAD Continue : The dog chased the cat [every day of the week]
|
| 136 |
+
Riddle me this : The dog chased the cat [yesterday] PAD PAD PAD PAD
|
| 137 |
+
Output:
|
| 138 |
+
[every day of the week]
|
| 139 |
+
[yesterday] PAD PAD PAD PAD
|
| 140 |
+
"""
|
| 141 |
+
return generations[:, max_context_size:]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class Reorderer:
|
| 145 |
+
def __init__(self, arr, fn):
|
| 146 |
+
self.size = len(arr)
|
| 147 |
+
arr = list(enumerate(arr))
|
| 148 |
+
arr = group(arr, lambda x: fn(x[1]))
|
| 149 |
+
arr = [([y[0] for y in x], x[0][1]) for x in arr]
|
| 150 |
+
arr.sort(key=lambda x: fn(x[1]))
|
| 151 |
+
|
| 152 |
+
self.arr = arr
|
| 153 |
+
|
| 154 |
+
def get_reordered(self):
|
| 155 |
+
return [x[1] for x in self.arr]
|
| 156 |
+
|
| 157 |
+
def get_original(self, newarr):
|
| 158 |
+
res = [None] * self.size
|
| 159 |
+
cov = [False] * self.size
|
| 160 |
+
|
| 161 |
+
for (inds, _), v in zip(self.arr, newarr):
|
| 162 |
+
for ind in inds:
|
| 163 |
+
res[ind] = v
|
| 164 |
+
cov[ind] = True
|
| 165 |
+
|
| 166 |
+
assert all(cov)
|
| 167 |
+
|
| 168 |
+
return res
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def positional_deprecated(fn):
|
| 172 |
+
"""
|
| 173 |
+
A decorator to nudge users into passing only keyword args (`kwargs`) to the
|
| 174 |
+
wrapped function, `fn`.
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
@functools.wraps(fn)
|
| 178 |
+
def _wrapper(*args, **kwargs):
|
| 179 |
+
if len(args) != 1 if inspect.ismethod(fn) else 0:
|
| 180 |
+
print(
|
| 181 |
+
f"WARNING: using {fn.__name__} with positional arguments is "
|
| 182 |
+
"deprecated and will be disallowed in a future version of "
|
| 183 |
+
"lm-evaluation-harness!"
|
| 184 |
+
)
|
| 185 |
+
return fn(*args, **kwargs)
|
| 186 |
+
|
| 187 |
+
return _wrapper
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
@positional_deprecated
|
| 191 |
+
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
|
| 192 |
+
"""
|
| 193 |
+
Search upward in the directory tree to a maximum of three layers
|
| 194 |
+
to find and return the package root (containing the 'tests' folder)
|
| 195 |
+
"""
|
| 196 |
+
cur_path = start_path.resolve()
|
| 197 |
+
max_layers = 3
|
| 198 |
+
for _ in range(max_layers):
|
| 199 |
+
if (cur_path / "tests" / "test_version_stable.py").exists():
|
| 200 |
+
return cur_path
|
| 201 |
+
else:
|
| 202 |
+
cur_path = cur_path.parent.resolve()
|
| 203 |
+
raise FileNotFoundError(
|
| 204 |
+
f"Unable to find package root within {max_layers} upwards" + f"of {start_path}"
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@positional_deprecated
|
| 209 |
+
def run_task_tests(task_list: List[str]):
|
| 210 |
+
"""
|
| 211 |
+
Find the package root and run the tests for the given tasks
|
| 212 |
+
"""
|
| 213 |
+
import pytest
|
| 214 |
+
|
| 215 |
+
package_root = find_test_root(start_path=pathlib.Path(__file__))
|
| 216 |
+
task_string = " or ".join(task_list)
|
| 217 |
+
args = [
|
| 218 |
+
f"{package_root}/tests/test_version_stable.py",
|
| 219 |
+
f"--rootdir={package_root}",
|
| 220 |
+
"-k",
|
| 221 |
+
f"{task_string}",
|
| 222 |
+
]
|
| 223 |
+
sys.path.append(str(package_root))
|
| 224 |
+
pytest_return_val = pytest.main(args)
|
| 225 |
+
if pytest_return_val:
|
| 226 |
+
raise ValueError(
|
| 227 |
+
f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def bleu(refs, preds):
|
| 232 |
+
"""
|
| 233 |
+
Returns `t5` style BLEU scores. See the related implementation:
|
| 234 |
+
https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L41
|
| 235 |
+
|
| 236 |
+
:param refs:
|
| 237 |
+
A `list` of `list` of reference `str`s.
|
| 238 |
+
:param preds:
|
| 239 |
+
A `list` of predicted `str`s.
|
| 240 |
+
"""
|
| 241 |
+
score = sacrebleu.corpus_bleu(
|
| 242 |
+
preds,
|
| 243 |
+
refs,
|
| 244 |
+
smooth_method="exp",
|
| 245 |
+
smooth_value=0.0,
|
| 246 |
+
force=False,
|
| 247 |
+
lowercase=False,
|
| 248 |
+
tokenize="intl",
|
| 249 |
+
use_effective_order=False,
|
| 250 |
+
).score
|
| 251 |
+
return score
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def rouge(refs, preds):
|
| 255 |
+
"""
|
| 256 |
+
Returns `t5` style ROUGE scores. See the related implementation:
|
| 257 |
+
https://github.com/google-research/text-to-text-transfer-transformer/blob/3d10afd51ba97ac29eb66ae701eca274488202f7/t5/evaluation/metrics.py#L68
|
| 258 |
+
|
| 259 |
+
:param refs:
|
| 260 |
+
A `list` of reference `strs`.
|
| 261 |
+
:param preds:
|
| 262 |
+
A `list` of predicted `strs`.
|
| 263 |
+
"""
|
| 264 |
+
rouge_types = ["rouge1", "rouge2", "rougeLsum"]
|
| 265 |
+
scorer = rouge_scorer.RougeScorer(rouge_types)
|
| 266 |
+
# Add newlines between sentences to correctly compute `rougeLsum`.
|
| 267 |
+
|
| 268 |
+
def _prepare_summary(summary):
|
| 269 |
+
summary = summary.replace(" . ", ".\n")
|
| 270 |
+
return summary
|
| 271 |
+
|
| 272 |
+
# Accumulate confidence intervals.
|
| 273 |
+
aggregator = scoring.BootstrapAggregator()
|
| 274 |
+
for ref, pred in zip(refs, preds):
|
| 275 |
+
ref = _prepare_summary(ref)
|
| 276 |
+
pred = _prepare_summary(pred)
|
| 277 |
+
aggregator.add_scores(scorer.score(ref, pred))
|
| 278 |
+
result = aggregator.aggregate()
|
| 279 |
+
return {type: result[type].mid.fmeasure * 100 for type in rouge_types}
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def rouge2_mecab(refs, preds, tokenizer):
|
| 283 |
+
"""This uses a MeCab tokenizer for Japanese text.
|
| 284 |
+
|
| 285 |
+
Besides specifying the tokenizer, this does not perform the rougeLsum
|
| 286 |
+
related sentence/newline normalization, and only calculates rouge2.
|
| 287 |
+
Otherwise it is the same as the generic rouge scoring.
|
| 288 |
+
"""
|
| 289 |
+
rouge_types = ["rouge2"]
|
| 290 |
+
# mecab-based rouge
|
| 291 |
+
scorer = rouge_scorer.RougeScorer(
|
| 292 |
+
rouge_types,
|
| 293 |
+
tokenizer=tokenizer,
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# Accumulate confidence intervals.
|
| 297 |
+
aggregator = scoring.BootstrapAggregator()
|
| 298 |
+
for ref, pred in zip(refs, preds):
|
| 299 |
+
aggregator.add_scores(scorer.score(ref, pred))
|
| 300 |
+
result = aggregator.aggregate()
|
| 301 |
+
return {type: result[type].mid.fmeasure * 100 for type in rouge_types}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/harness.jsquad-1.2.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-1b,device_map=auto,torch_dtype=auto"
|
| 2 |
+
TASK="jsquad-1.2-0.2"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/cyberagent-open-calm-1b/result.jsquad-1.2.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/harness.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-1b"
|
| 2 |
+
TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,xlsum_ja"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2,3,3,3,1" --device "cuda" --output_path "models/cyberagent-open-calm-1b/result.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/result.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jcommonsenseqa-1.1-0.2": {
|
| 4 |
+
"acc": 0.26899016979445933,
|
| 5 |
+
"acc_stderr": 0.013261996572328063,
|
| 6 |
+
"acc_norm": 0.24754244861483468,
|
| 7 |
+
"acc_norm_stderr": 0.01290758346346734
|
| 8 |
+
},
|
| 9 |
+
"jnli-1.1-0.2": {
|
| 10 |
+
"acc": 0.33566146261298274,
|
| 11 |
+
"acc_stderr": 0.00957358086224245,
|
| 12 |
+
"acc_norm": 0.3331963845521775,
|
| 13 |
+
"acc_norm_stderr": 0.009556042193601356
|
| 14 |
+
},
|
| 15 |
+
"marc_ja-1.1-0.2": {
|
| 16 |
+
"acc": 0.7792117195674921,
|
| 17 |
+
"acc_stderr": 0.005478034657719626,
|
| 18 |
+
"acc_norm": 0.7792117195674921,
|
| 19 |
+
"acc_norm_stderr": 0.005478034657719626
|
| 20 |
+
},
|
| 21 |
+
"jsquad-1.1-0.2": {
|
| 22 |
+
"exact_match": 37.12291760468258,
|
| 23 |
+
"f1": 47.171446643186265
|
| 24 |
+
},
|
| 25 |
+
"xlsum_ja": {
|
| 26 |
+
"rouge2": 2.288077088085482
|
| 27 |
+
},
|
| 28 |
+
"xwinograd_ja": {
|
| 29 |
+
"acc": 0.6089676746611054,
|
| 30 |
+
"acc_stderr": 0.015765969995357912
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
"versions": {
|
| 34 |
+
"jcommonsenseqa-1.1-0.2": 1.1,
|
| 35 |
+
"jnli-1.1-0.2": 1.1,
|
| 36 |
+
"jsquad-1.1-0.2": 1.1,
|
| 37 |
+
"marc_ja-1.1-0.2": 1.1,
|
| 38 |
+
"xlsum_ja": 1.0,
|
| 39 |
+
"xwinograd_ja": 1.0
|
| 40 |
+
},
|
| 41 |
+
"config": {
|
| 42 |
+
"model": "hf-causal",
|
| 43 |
+
"model_args": "pretrained=cyberagent/open-calm-1b",
|
| 44 |
+
"num_fewshot": [
|
| 45 |
+
2,
|
| 46 |
+
3,
|
| 47 |
+
3,
|
| 48 |
+
3,
|
| 49 |
+
1,
|
| 50 |
+
0
|
| 51 |
+
],
|
| 52 |
+
"batch_size": null,
|
| 53 |
+
"device": "cuda",
|
| 54 |
+
"no_cache": false,
|
| 55 |
+
"limit": null,
|
| 56 |
+
"bootstrap_iters": 100000,
|
| 57 |
+
"description_dict": {}
|
| 58 |
+
}
|
| 59 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/result.jsquad-1.2.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jsquad-1.2-0.2": {
|
| 4 |
+
"exact_match": 39.53174245835209,
|
| 5 |
+
"f1": 49.49399460234075
|
| 6 |
+
}
|
| 7 |
+
},
|
| 8 |
+
"versions": {
|
| 9 |
+
"jsquad-1.2-0.2": 1.2
|
| 10 |
+
},
|
| 11 |
+
"config": {
|
| 12 |
+
"model": "hf-causal",
|
| 13 |
+
"model_args": "pretrained=cyberagent/open-calm-1b",
|
| 14 |
+
"num_fewshot": 3,
|
| 15 |
+
"batch_size": null,
|
| 16 |
+
"device": "cuda",
|
| 17 |
+
"no_cache": false,
|
| 18 |
+
"limit": null,
|
| 19 |
+
"bootstrap_iters": 100000,
|
| 20 |
+
"description_dict": {}
|
| 21 |
+
}
|
| 22 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-1b/result.mgsm.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/harness.jsquad-1.2.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-3b,device_map=auto,torch_dtype=auto"
|
| 2 |
+
TASK="jsquad-1.2-0.2"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2" --device "cuda" --output_path "models/cyberagent/cyberagent-open-calm-3b/result.jsquad-1.2.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/harness.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-3b"
|
| 2 |
+
TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,jaqket_v2-0.1-0.2,xlsum_ja,xwinograd_ja,mgsm"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3,3,3,2,1,1,0,5" --device "cuda" --output_path "models/cyberagent/cyberagent-open-calm-3b/result.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/result.json
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jcommonsenseqa-1.1-0.2": {
|
| 4 |
+
"acc": 0.2779267202859696,
|
| 5 |
+
"acc_stderr": 0.013397843071173697,
|
| 6 |
+
"acc_norm": 0.2529043789097408,
|
| 7 |
+
"acc_norm_stderr": 0.013000060342436679
|
| 8 |
+
},
|
| 9 |
+
"jnli-1.1-0.2": {
|
| 10 |
+
"acc": 0.40345110928512734,
|
| 11 |
+
"acc_stderr": 0.009945976384444125,
|
| 12 |
+
"acc_norm": 0.37674609695973704,
|
| 13 |
+
"acc_norm_stderr": 0.009823942907406487
|
| 14 |
+
},
|
| 15 |
+
"marc_ja-1.1-0.2": {
|
| 16 |
+
"acc": 0.8620509243111266,
|
| 17 |
+
"acc_stderr": 0.004554438976572761,
|
| 18 |
+
"acc_norm": 0.8620509243111266,
|
| 19 |
+
"acc_norm_stderr": 0.004554438976572761
|
| 20 |
+
},
|
| 21 |
+
"xwinograd_ja": {
|
| 22 |
+
"acc": 0.6360792492179353,
|
| 23 |
+
"acc_stderr": 0.015544482535576241
|
| 24 |
+
},
|
| 25 |
+
"jsquad-1.1-0.2": {
|
| 26 |
+
"exact_match": 40.45475011256191,
|
| 27 |
+
"f1": 52.73709875917724
|
| 28 |
+
},
|
| 29 |
+
"jaqket_v2-0.1-0.2": {
|
| 30 |
+
"exact_match": 46.90721649484536,
|
| 31 |
+
"f1": 51.615597556319194
|
| 32 |
+
},
|
| 33 |
+
"xlsum_ja": {
|
| 34 |
+
"rouge2": 1.948450071736146
|
| 35 |
+
},
|
| 36 |
+
"mgsm": {
|
| 37 |
+
"acc": 0.016,
|
| 38 |
+
"acc_stderr": 0.007951661188874344
|
| 39 |
+
}
|
| 40 |
+
},
|
| 41 |
+
"versions": {
|
| 42 |
+
"jcommonsenseqa-1.1-0.2": 1.1,
|
| 43 |
+
"jnli-1.1-0.2": 1.1,
|
| 44 |
+
"marc_ja-1.1-0.2": 1.1,
|
| 45 |
+
"jsquad-1.1-0.2": 1.1,
|
| 46 |
+
"jaqket_v2-0.1-0.2": 0.1,
|
| 47 |
+
"xlsum_ja": 1.0,
|
| 48 |
+
"xwinograd_ja": 1.0,
|
| 49 |
+
"mgsm": 1.0
|
| 50 |
+
},
|
| 51 |
+
"config": {
|
| 52 |
+
"model": "hf-causal",
|
| 53 |
+
"model_args": "pretrained=cyberagent/open-calm-3b",
|
| 54 |
+
"num_fewshot": [
|
| 55 |
+
3,
|
| 56 |
+
3,
|
| 57 |
+
3,
|
| 58 |
+
2,
|
| 59 |
+
1,
|
| 60 |
+
1,
|
| 61 |
+
0,
|
| 62 |
+
5
|
| 63 |
+
],
|
| 64 |
+
"batch_size": null,
|
| 65 |
+
"device": "cuda",
|
| 66 |
+
"no_cache": false,
|
| 67 |
+
"limit": null,
|
| 68 |
+
"bootstrap_iters": 100000,
|
| 69 |
+
"description_dict": {}
|
| 70 |
+
}
|
| 71 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/result.jsquad-1.2.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jsquad-1.2-0.2": {
|
| 4 |
+
"exact_match": 44.529491220171096,
|
| 5 |
+
"f1": 56.02141036867636
|
| 6 |
+
}
|
| 7 |
+
},
|
| 8 |
+
"versions": {
|
| 9 |
+
"jsquad-1.2-0.2": 1.2
|
| 10 |
+
},
|
| 11 |
+
"config": {
|
| 12 |
+
"model": "hf-causal",
|
| 13 |
+
"model_args": "pretrained=cyberagent/open-calm-3b,device_map=auto,torch_dtype=auto",
|
| 14 |
+
"num_fewshot": 2,
|
| 15 |
+
"batch_size": null,
|
| 16 |
+
"device": "cuda",
|
| 17 |
+
"no_cache": false,
|
| 18 |
+
"limit": null,
|
| 19 |
+
"bootstrap_iters": 100000,
|
| 20 |
+
"description_dict": {}
|
| 21 |
+
}
|
| 22 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-3b/result.mgsm.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/harness.jsquad-1.2.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-7b,device_map=auto,torch_dtype=auto"
|
| 2 |
+
TASK="jsquad-1.2-0.2"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2" --device "cuda" --output_path "models/cyberagent/cyberagent-open-calm-7b/result.jsquad-1.2.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/harness.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-7b"
|
| 2 |
+
TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,jaqket_v2-0.1-0.2,xlsum_ja,xwinograd_ja,mgsm"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3,3,3,2,1,1,0,5" --device "cuda" --output_path "models/cyberagent/cyberagent-open-calm-7b/result.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/result.json
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jcommonsenseqa-1.1-0.2": {
|
| 4 |
+
"acc": 0.2421805183199285,
|
| 5 |
+
"acc_stderr": 0.012812432289317907,
|
| 6 |
+
"acc_norm": 0.24396782841823056,
|
| 7 |
+
"acc_norm_stderr": 0.012844450125623429
|
| 8 |
+
},
|
| 9 |
+
"jnli-1.1-0.2": {
|
| 10 |
+
"acc": 0.3763352506162695,
|
| 11 |
+
"acc_stderr": 0.00982182053150895,
|
| 12 |
+
"acc_norm": 0.3463434675431389,
|
| 13 |
+
"acc_norm_stderr": 0.009646221914241809
|
| 14 |
+
},
|
| 15 |
+
"marc_ja-1.1-0.2": {
|
| 16 |
+
"acc": 0.7411928845483083,
|
| 17 |
+
"acc_stderr": 0.005784459117732042,
|
| 18 |
+
"acc_norm": 0.7411928845483083,
|
| 19 |
+
"acc_norm_stderr": 0.005784459117732042
|
| 20 |
+
},
|
| 21 |
+
"xwinograd_ja": {
|
| 22 |
+
"acc": 0.6506777893639207,
|
| 23 |
+
"acc_stderr": 0.01540328448938605
|
| 24 |
+
},
|
| 25 |
+
"jsquad-1.1-0.2": {
|
| 26 |
+
"exact_match": 45.79018460153084,
|
| 27 |
+
"f1": 59.03158509144496
|
| 28 |
+
},
|
| 29 |
+
"jaqket_v2-0.1-0.2": {
|
| 30 |
+
"exact_match": 60.738831615120276,
|
| 31 |
+
"f1": 64.89929362352039
|
| 32 |
+
},
|
| 33 |
+
"xlsum_ja": {
|
| 34 |
+
"rouge2": 2.0382422339290223
|
| 35 |
+
},
|
| 36 |
+
"mgsm": {
|
| 37 |
+
"acc": 0.008,
|
| 38 |
+
"acc_stderr": 0.005645483676690164
|
| 39 |
+
}
|
| 40 |
+
},
|
| 41 |
+
"versions": {
|
| 42 |
+
"jcommonsenseqa-1.1-0.2": 1.1,
|
| 43 |
+
"jnli-1.1-0.2": 1.1,
|
| 44 |
+
"marc_ja-1.1-0.2": 1.1,
|
| 45 |
+
"jsquad-1.1-0.2": 1.1,
|
| 46 |
+
"jaqket_v2-0.1-0.2": 0.1,
|
| 47 |
+
"xlsum_ja": 1.0,
|
| 48 |
+
"xwinograd_ja": 1.0,
|
| 49 |
+
"mgsm": 1.0
|
| 50 |
+
},
|
| 51 |
+
"config": {
|
| 52 |
+
"model": "hf-causal",
|
| 53 |
+
"model_args": "pretrained=cyberagent/open-calm-7b",
|
| 54 |
+
"num_fewshot": [
|
| 55 |
+
3,
|
| 56 |
+
3,
|
| 57 |
+
3,
|
| 58 |
+
2,
|
| 59 |
+
1,
|
| 60 |
+
1,
|
| 61 |
+
0,
|
| 62 |
+
5
|
| 63 |
+
],
|
| 64 |
+
"batch_size": null,
|
| 65 |
+
"device": "cuda",
|
| 66 |
+
"no_cache": false,
|
| 67 |
+
"limit": null,
|
| 68 |
+
"bootstrap_iters": 100000,
|
| 69 |
+
"description_dict": {}
|
| 70 |
+
}
|
| 71 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/result.jsquad-1.2.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jsquad-1.2-0.2": {
|
| 4 |
+
"exact_match": 48.10895992796038,
|
| 5 |
+
"f1": 60.90961937230767
|
| 6 |
+
}
|
| 7 |
+
},
|
| 8 |
+
"versions": {
|
| 9 |
+
"jsquad-1.2-0.2": 1.2
|
| 10 |
+
},
|
| 11 |
+
"config": {
|
| 12 |
+
"model": "hf-causal",
|
| 13 |
+
"model_args": "pretrained=cyberagent/open-calm-7b,device_map=auto,torch_dtype=auto",
|
| 14 |
+
"num_fewshot": 2,
|
| 15 |
+
"batch_size": null,
|
| 16 |
+
"device": "cuda",
|
| 17 |
+
"no_cache": false,
|
| 18 |
+
"limit": null,
|
| 19 |
+
"bootstrap_iters": 100000,
|
| 20 |
+
"description_dict": {}
|
| 21 |
+
}
|
| 22 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-7b/result.mgsm.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-large/harness.jsquad-1.2.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-large,use_fast=True,device_map=auto,torch_dtype=auto"
|
| 2 |
+
TASK="jsquad-1.2-0.2"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/cyberagent-open-calm-large/result.jsquad-1.2.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-large/harness.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-large,use_fast=True"
|
| 2 |
+
TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,xlsum_ja"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2,3,3,3,1" --device "cuda" --output_path "models/cyberagent-open-calm-large/result.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-large/result.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jcommonsenseqa-1.1-0.2": {
|
| 4 |
+
"acc": 0.2993744414655943,
|
| 5 |
+
"acc_stderr": 0.013697125864334919,
|
| 6 |
+
"acc_norm": 0.2752457551385165,
|
| 7 |
+
"acc_norm_stderr": 0.013357795705028184
|
| 8 |
+
},
|
| 9 |
+
"jnli-1.1-0.2": {
|
| 10 |
+
"acc": 0.40838126540673786,
|
| 11 |
+
"acc_stderr": 0.009965126356916034,
|
| 12 |
+
"acc_norm": 0.3751027115858669,
|
| 13 |
+
"acc_norm_stderr": 0.009815408241248635
|
| 14 |
+
},
|
| 15 |
+
"marc_ja-1.1-0.2": {
|
| 16 |
+
"acc": 0.7912452040460412,
|
| 17 |
+
"acc_stderr": 0.005367632889806105,
|
| 18 |
+
"acc_norm": 0.7912452040460412,
|
| 19 |
+
"acc_norm_stderr": 0.005367632889806105
|
| 20 |
+
},
|
| 21 |
+
"jsquad-1.1-0.2": {
|
| 22 |
+
"exact_match": 37.23547951373255,
|
| 23 |
+
"f1": 48.50349592141573
|
| 24 |
+
},
|
| 25 |
+
"xlsum_ja": {
|
| 26 |
+
"rouge2": 1.9854375467671679
|
| 27 |
+
},
|
| 28 |
+
"xwinograd_ja": {
|
| 29 |
+
"acc": 0.6152241918665277,
|
| 30 |
+
"acc_stderr": 0.015719467393137274
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
"versions": {
|
| 34 |
+
"jcommonsenseqa-1.1-0.2": 1.1,
|
| 35 |
+
"jnli-1.1-0.2": 1.1,
|
| 36 |
+
"jsquad-1.1-0.2": 1.1,
|
| 37 |
+
"marc_ja-1.1-0.2": 1.1,
|
| 38 |
+
"xlsum_ja": 1.0,
|
| 39 |
+
"xwinograd_ja": 1.0
|
| 40 |
+
},
|
| 41 |
+
"config": {
|
| 42 |
+
"model": "hf-causal",
|
| 43 |
+
"model_args": "pretrained=cyberagent/open-calm-large,use_fast=True",
|
| 44 |
+
"num_fewshot": [
|
| 45 |
+
2,
|
| 46 |
+
3,
|
| 47 |
+
3,
|
| 48 |
+
3,
|
| 49 |
+
1,
|
| 50 |
+
0
|
| 51 |
+
],
|
| 52 |
+
"batch_size": null,
|
| 53 |
+
"device": "cuda",
|
| 54 |
+
"no_cache": false,
|
| 55 |
+
"limit": null,
|
| 56 |
+
"bootstrap_iters": 100000,
|
| 57 |
+
"description_dict": {}
|
| 58 |
+
}
|
| 59 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-large/result.jsquad-1.2.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jsquad-1.2-0.2": {
|
| 4 |
+
"exact_match": 40.4997748761819,
|
| 5 |
+
"f1": 51.32160467436942
|
| 6 |
+
}
|
| 7 |
+
},
|
| 8 |
+
"versions": {
|
| 9 |
+
"jsquad-1.2-0.2": 1.2
|
| 10 |
+
},
|
| 11 |
+
"config": {
|
| 12 |
+
"model": "hf-causal",
|
| 13 |
+
"model_args": "pretrained=cyberagent/open-calm-large,use_fast=True,device_map=auto,torch_dtype=auto",
|
| 14 |
+
"num_fewshot": 3,
|
| 15 |
+
"batch_size": null,
|
| 16 |
+
"device": "cuda",
|
| 17 |
+
"no_cache": false,
|
| 18 |
+
"limit": null,
|
| 19 |
+
"bootstrap_iters": 100000,
|
| 20 |
+
"description_dict": {}
|
| 21 |
+
}
|
| 22 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-medium/harness.jsquad-1.2.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-medium,use_fast=True,device_map=auto,torch_dtype=auto"
|
| 2 |
+
TASK="jsquad-1.2-0.2"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "3" --device "cuda" --output_path "models/cyberagent-open-calm-medium/result.jsquad-1.2.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-medium/harness.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=cyberagent/open-calm-medium,use_fast=True"
|
| 2 |
+
TASK="jcommonsenseqa-1.1-0.2,jnli-1.1-0.2,marc_ja-1.1-0.2,jsquad-1.1-0.2,xlsum_ja"
|
| 3 |
+
python main.py --model hf-causal --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2,3,3,3,1" --device "cuda" --output_path "models/cyberagent-open-calm-medium/result.json"
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-medium/result.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jcommonsenseqa-1.1-0.2": {
|
| 4 |
+
"acc": 0.39499553172475427,
|
| 5 |
+
"acc_stderr": 0.0146202392872941,
|
| 6 |
+
"acc_norm": 0.2868632707774799,
|
| 7 |
+
"acc_norm_stderr": 0.013527046208250626
|
| 8 |
+
},
|
| 9 |
+
"jnli-1.1-0.2": {
|
| 10 |
+
"acc": 0.4231717337715694,
|
| 11 |
+
"acc_stderr": 0.010016374130527417,
|
| 12 |
+
"acc_norm": 0.3972884141331142,
|
| 13 |
+
"acc_norm_stderr": 0.009920570907906705
|
| 14 |
+
},
|
| 15 |
+
"marc_ja-1.1-0.2": {
|
| 16 |
+
"acc": 0.8357167771189397,
|
| 17 |
+
"acc_stderr": 0.004893675823612713,
|
| 18 |
+
"acc_norm": 0.8357167771189397,
|
| 19 |
+
"acc_norm_stderr": 0.004893675823612713
|
| 20 |
+
},
|
| 21 |
+
"jsquad-1.1-0.2": {
|
| 22 |
+
"exact_match": 28.725799189554255,
|
| 23 |
+
"f1": 39.80333448254385
|
| 24 |
+
},
|
| 25 |
+
"xlsum_ja": {
|
| 26 |
+
"rouge2": 2.5775988917922406
|
| 27 |
+
},
|
| 28 |
+
"xwinograd_ja": {
|
| 29 |
+
"acc": 0.5964546402502607,
|
| 30 |
+
"acc_stderr": 0.015850834635341565
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
"versions": {
|
| 34 |
+
"jcommonsenseqa-1.1-0.2": 1.1,
|
| 35 |
+
"jnli-1.1-0.2": 1.1,
|
| 36 |
+
"jsquad-1.1-0.2": 1.1,
|
| 37 |
+
"marc_ja-1.1-0.2": 1.1,
|
| 38 |
+
"xlsum_ja": 1.0,
|
| 39 |
+
"xwinograd_ja": 1.0
|
| 40 |
+
},
|
| 41 |
+
"config": {
|
| 42 |
+
"model": "hf-causal",
|
| 43 |
+
"model_args": "pretrained=cyberagent/open-calm-medium,use_fast=True",
|
| 44 |
+
"num_fewshot": [
|
| 45 |
+
2,
|
| 46 |
+
3,
|
| 47 |
+
3,
|
| 48 |
+
3,
|
| 49 |
+
1,
|
| 50 |
+
0
|
| 51 |
+
],
|
| 52 |
+
"batch_size": null,
|
| 53 |
+
"device": "cuda",
|
| 54 |
+
"no_cache": false,
|
| 55 |
+
"limit": null,
|
| 56 |
+
"bootstrap_iters": 100000,
|
| 57 |
+
"description_dict": {}
|
| 58 |
+
}
|
| 59 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/cyberagent/cyberagent-open-calm-medium/result.jsquad-1.2.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jsquad-1.2-0.2": {
|
| 4 |
+
"exact_match": 29.85141828005403,
|
| 5 |
+
"f1": 40.49655778214922
|
| 6 |
+
}
|
| 7 |
+
},
|
| 8 |
+
"versions": {
|
| 9 |
+
"jsquad-1.2-0.2": 1.2
|
| 10 |
+
},
|
| 11 |
+
"config": {
|
| 12 |
+
"model": "hf-causal",
|
| 13 |
+
"model_args": "pretrained=cyberagent/open-calm-medium,use_fast=True,device_map=auto,torch_dtype=auto",
|
| 14 |
+
"num_fewshot": 3,
|
| 15 |
+
"batch_size": null,
|
| 16 |
+
"device": "cuda",
|
| 17 |
+
"no_cache": false,
|
| 18 |
+
"limit": null,
|
| 19 |
+
"bootstrap_iters": 100000,
|
| 20 |
+
"description_dict": {}
|
| 21 |
+
}
|
| 22 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/llama/llama-13b/harness.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=huggyllama/llama-13b,use_accelerate=True,load_in_8bit=True"
|
| 2 |
+
TASK="jsquad-1.1-0.3,jcommonsenseqa-1.1-0.3,jnli-1.1-0.3,marc_ja-1.1-0.3"
|
| 3 |
+
python main.py --model hf-causal-experimental --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2,3,3,3" --device "cuda" --output_path "models/llama/llama-13b/result.json" --batch_size 2 > models/llama/llama-13b/harness.out 2> models/llama/llama-13b/harness.err
|
scripts/yans/eval/lm-evaluation-harness/models/llama/llama-13b/result.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"results": {
|
| 3 |
+
"jsquad-1.1-0.3": {
|
| 4 |
+
"exact_match": 51.05808194506979,
|
| 5 |
+
"f1": 65.19689339101781
|
| 6 |
+
},
|
| 7 |
+
"jcommonsenseqa-1.1-0.3": {
|
| 8 |
+
"acc": 0.4932975871313673,
|
| 9 |
+
"acc_stderr": 0.014952371541808172,
|
| 10 |
+
"acc_norm": 0.29848078641644327,
|
| 11 |
+
"acc_norm_stderr": 0.013685386698397504
|
| 12 |
+
},
|
| 13 |
+
"jnli-1.1-0.3": {
|
| 14 |
+
"acc": 0.24116680361544782,
|
| 15 |
+
"acc_stderr": 0.008672830725110452,
|
| 16 |
+
"acc_norm": 0.30156121610517667,
|
| 17 |
+
"acc_norm_stderr": 0.009304239098715018
|
| 18 |
+
},
|
| 19 |
+
"marc_ja-1.1-0.3": {
|
| 20 |
+
"acc": 0.8791419602371817,
|
| 21 |
+
"acc_stderr": 0.004305031232204757,
|
| 22 |
+
"acc_norm": 0.8791419602371817,
|
| 23 |
+
"acc_norm_stderr": 0.004305031232204757
|
| 24 |
+
}
|
| 25 |
+
},
|
| 26 |
+
"versions": {
|
| 27 |
+
"jsquad-1.1-0.3": 1.1,
|
| 28 |
+
"jcommonsenseqa-1.1-0.3": 1.1,
|
| 29 |
+
"jnli-1.1-0.3": 1.1,
|
| 30 |
+
"marc_ja-1.1-0.3": 1.1
|
| 31 |
+
},
|
| 32 |
+
"config": {
|
| 33 |
+
"model": "hf-causal-experimental",
|
| 34 |
+
"model_args": "pretrained=huggyllama/llama-13b,use_accelerate=True,load_in_8bit=True",
|
| 35 |
+
"num_fewshot": [
|
| 36 |
+
2,
|
| 37 |
+
3,
|
| 38 |
+
3,
|
| 39 |
+
3
|
| 40 |
+
],
|
| 41 |
+
"batch_size": 2,
|
| 42 |
+
"device": "cuda",
|
| 43 |
+
"no_cache": false,
|
| 44 |
+
"limit": null,
|
| 45 |
+
"bootstrap_iters": 100000,
|
| 46 |
+
"description_dict": {}
|
| 47 |
+
}
|
| 48 |
+
}
|
scripts/yans/eval/lm-evaluation-harness/models/llama/llama-30b/harness.sh
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL_ARGS="pretrained=huggyllama/llama-30b,use_accelerate=True,load_in_8bit=True"
|
| 2 |
+
TASK="jsquad-1.1-0.3,jcommonsenseqa-1.1-0.3,jnli-1.1-0.3,marc_ja-1.1-0.3"
|
| 3 |
+
python main.py --model hf-causal-experimental --model_args $MODEL_ARGS --tasks $TASK --num_fewshot "2,3,3,3" --device "cuda" --output_path "models/llama/llama-30b/result.json" --batch_size 2 > models/llama/llama-30b/harness.out 2> models/llama/llama-30b/harness.err
|