'''
This script fetches all the models used in the server tests.

This is useful for slow tests that use larger models, to avoid them timing out on the model downloads.

It is meant to be run from the root of the repository.

Example:
    python scripts/fetch_server_test_models.py
    ( cd tools/server/tests && ./tests.sh -v -x -m slow )
'''
|
import ast
import glob
import logging
import os
import subprocess
import sys
from typing import Generator, Optional

from pydantic import BaseModel
|
| |
|
| |
|
class HuggingFaceModel(BaseModel):
    '''A model reference extracted from a test's `parametrize` decorator.'''
    # Hugging Face repository id (e.g. "ggml-org/models").
    hf_repo: str
    # GGUF file path within the repo; None when no explicit file is parameterized.
    hf_file: Optional[str] = None

    class Config:
        # Immutable + hashable, so instances can be deduplicated in a set.
        frozen = True
|
| |
|
| |
|
def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]:
    '''Yield every Hugging Face model referenced by `@pytest.mark.parametrize`
    decorators in the given test file.

    The file is parsed with `ast` — the test module is never imported or
    executed. Read/parse failures are logged and the file is skipped.

    Args:
        test_file: path of a pytest unit-test source file.
    Yields:
        HuggingFaceModel for each parametrize tuple that includes an `hf_repo`
        (and optionally an `hf_file`) parameter.
    '''
    try:
        # Explicit encoding: test sources are UTF-8 regardless of platform default.
        with open(test_file, encoding='utf-8') as f:
            tree = ast.parse(f.read())
    except Exception as e:
        logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}')
        return

    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue
        for dec in node.decorator_list:
            # Only `something.parametrize(...)` call decorators are of interest.
            if not (isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize'):
                continue
            # First argument is the comma-separated parameter-name string.
            # Strip each name so "hf_repo, hf_file" works as well as "hf_repo,hf_file".
            param_names = [p.strip() for p in ast.literal_eval(dec.args[0]).split(",")]
            if "hf_repo" not in param_names:
                continue

            # Second argument must be a literal list of literal tuples.
            raw_param_values = dec.args[1]
            if not isinstance(raw_param_values, ast.List):
                logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}')
                continue

            hf_repo_idx = param_names.index("hf_repo")
            hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None

            for t in raw_param_values.elts:
                if not isinstance(t, ast.Tuple):
                    logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}')
                    continue
                yield HuggingFaceModel(
                    hf_repo=ast.literal_eval(t.elts[hf_repo_idx]),
                    hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None)
|
| |
|
| |
|
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    # Deduplicate via a set (HuggingFaceModel is frozen, hence hashable) and
    # sort for stable output. `hf_file or ''` keeps the key comparable when
    # hf_file is None (None < str raises TypeError).
    models = sorted({
        model
        for test_file in glob.glob('tools/server/tests/unit/test_*.py')
        for model in collect_hf_model_test_parameters(test_file)
    }, key=lambda m: (m.hf_repo, m.hf_file or ''))

    logging.info(f'Found {len(models)} models in parameterized tests:')
    for m in models:
        logging.info(f'  - {m.hf_repo} / {m.hf_file}')

    # Allow overriding the llama-cli binary; default to the build tree next to
    # this script (Windows Release layout vs. Unix layout).
    cli_path = os.environ.get(
        'LLAMA_CLI_BIN_PATH',
        os.path.join(
            os.path.dirname(__file__),
            '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli'))

    for m in models:
        # Skip templated placeholders (e.g. "<repo>") that aren't concrete models.
        if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file):
            continue
        # Multi-part (split) GGUF files can't be fetched this way.
        if m.hf_file is not None and '-of-' in m.hf_file:
            logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file')
            continue
        logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched')
        # Run a single-token generation so llama-cli downloads/caches the model.
        cmd = [
            cli_path,
            '-hfr', m.hf_repo,
            *([] if m.hf_file is None else ['-hff', m.hf_file]),
            '-n', '1',
            '-p', 'Hey',
            '--no-warmup',
            '--log-disable',
            '-no-cnv']
        # Flash attention is known not to work with these specific models.
        if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo:
            cmd.append('-fa')
        try:
            subprocess.check_call(cmd)
        except subprocess.CalledProcessError:
            logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}')
            sys.exit(1)
|
| |
|