Spaces:
Paused
Paused
| import pkg_resources | |
| from pkg_resources import DistributionNotFound, VersionConflict | |
| from src.utils import remove | |
| from tests.utils import wrap_test_forked | |
def get_all_requirements():
    """Gather requirements from the main file and every optional requirements file.

    The files read are ``requirements.txt`` followed by each
    ``reqs_optional/req*.txt`` match, in the order ``glob.glob`` yields them.
    Each file is passed through ``get_requirements``.

    Returns:
        tuple: ``(requirements, reqs_http)``. ``requirements`` is a list holding
        every parsed requirement of all files. ``reqs_http`` is a list holding
        every URL-bearing line of all files.
    """
    import glob

    req_files = ['requirements.txt']
    req_files += glob.glob('reqs_optional/req*.txt')

    parsed, http_lines = [], []
    for path in req_files:
        file_parsed, file_http = get_requirements(path)
        # list += accepts any iterable, so a lazy generator is consumed here too.
        parsed += file_parsed
        http_lines += file_http
    return parsed, http_lines
def get_requirements(req_file="requirements.txt"):
    """Split one requirements file into parsed requirements and raw URL lines.

    Lines whose non-comment part mentions ``http://`` or ``https://`` cannot be
    checked with ``pkg_resources``. Examples are direct wheel URLs, VCS links
    and ``--extra-index-url`` options. Such lines are returned, trimmed, in
    ``reqs_http``. Every other line is handed to
    ``pkg_resources.parse_requirements``.

    Args:
        req_file: path to a pip requirements file.

    Returns:
        tuple: ``(requirements, reqs_http)``. ``requirements`` is a list of the
        objects produced by ``pkg_resources.parse_requirements``. ``reqs_http``
        is a list of the URL-bearing lines, with comments and surrounding
        whitespace removed.

    Raises:
        OSError: if ``req_file`` cannot be opened or read.
        Parse errors from ``pkg_resources.parse_requirements`` propagate.
    """
    import re

    # Follows pip's comment rule: '#' at line start or after whitespace begins
    # a comment. A '#' glued to a URL (e.g. '#egg=name') is not a comment.
    # This keeps "# see https://..." notes out of the install list.
    comment_re = re.compile(r'(^|\s+)#.*$')

    plain_lines = []
    reqs_http = []
    with open(req_file, 'rt') as f:
        for line in f:
            code = comment_re.sub('', line).strip()
            if 'http://' in code or 'https://' in code:
                reqs_http.append(code)
            else:
                # Kept verbatim (comments included); parse_requirements skips comments.
                plain_lines.append(line)
    print('reqs_http: %s' % reqs_http, flush=True)

    # Parse in memory rather than through a sibling '.tmp.txt' file. The result
    # is materialised so that parse errors are raised here, not later in the caller.
    requirements = list(pkg_resources.parse_requirements(''.join(plain_lines)))
    return requirements, reqs_http
def test_requirements():
    """Test that each required package is available.

    Every requirement returned by ``get_all_requirements`` is checked against
    the installed distributions with ``pkg_resources.require``.

    If any requirement is missing (``DistributionNotFound``) or installed at a
    non-matching version (``VersionConflict``), the test does three things:

    - prints a summary;
    - prints a suggested pip command;
    - raises ``ValueError`` with the offending requirements. The URL-bearing
      lines from the requirements files are always appended to that list.

    Requirements that fail with ``InvalidRequirement`` are listed in the
    summary but do not, on their own, fail the test.
    """
    import shlex

    packages_all = []
    packages_dist = []
    packages_version = []
    packages_unkn = []

    requirements, reqs_http = get_all_requirements()

    for requirement in requirements:
        try:
            requirement = str(requirement)
            pkg_resources.require(requirement)
        except DistributionNotFound:
            packages_all.append(requirement)
            packages_dist.append(requirement)
        except VersionConflict:
            packages_all.append(requirement)
            packages_version.append(requirement)
        # NOTE(review): pkg_resources.extern is setuptools-internal and may be
        # absent in newer setuptools releases — confirm against the pinned version.
        except pkg_resources.extern.packaging.requirements.InvalidRequirement:
            packages_all.append(requirement)
            packages_unkn.append(requirement)

    # Parsed specifiers such as "pkg>=1.0" or 'pkg; sys_platform == "linux"'
    # contain shell metacharacters ('>', '<', ';', spaces). Unquoted, '>' would
    # redirect pip's output to a file, so they are quoted for the copy-paste
    # command. URL lines pass through verbatim because they may be pip options
    # ("--extra-index-url https://...") that must stay separate shell words.
    install_args = [shlex.quote(p) for p in packages_all] + list(reqs_http)
    packages_all.extend(reqs_http)
    if packages_dist or packages_version:
        print('Missing packages: %s' % packages_dist, flush=True)
        print('Wrong version of packages: %s' % packages_version, flush=True)
        print("Can't determine (e.g. http) packages: %s" % packages_unkn, flush=True)
        print('\n\nRUN THIS:\n\n', flush=True)
        print(
            'pip uninstall peft transformers accelerate -y ; '
            'CUDA_HOME=/usr/local/cuda-11.7 pip install %s --upgrade' % ' '.join(install_args),
            flush=True)
        print('\n\n', flush=True)
        raise ValueError(packages_all)
| import requests | |
| import json | |
| try: | |
| from packaging.version import parse | |
| except ImportError: | |
| from pip._vendor.packaging.version import parse | |
# PyPI JSON endpoint used by get_version; '{package}' is substituted via str.format.
# NOTE(review): pypi.python.org is the legacy hostname. The documented endpoint is
# https://pypi.org/pypi/<project>/json — confirm relying on the redirect is acceptable.
URL_PATTERN = 'https://pypi.python.org/pypi/{package}/json'
def get_version(package, url_pattern=URL_PATTERN, timeout=30):
    """Return the newest non-prerelease version of *package* listed on PyPI.

    Args:
        package: distribution name to look up.
        url_pattern: JSON API URL containing a ``{package}`` placeholder.
        timeout: seconds to wait for the HTTP response. This prevents the
            request from blocking forever; pass ``None`` for no timeout.

    Returns:
        The largest parsed release that is not a prerelease, or ``parse('0')``
        in either of two cases: the request did not return HTTP 200, or no
        usable stable release was found.

    Raises:
        requests exceptions (e.g. connection errors, timeouts) propagate.
    """
    req = requests.get(url_pattern.format(package=package), timeout=timeout)
    version = parse('0')
    if req.status_code == requests.codes.ok:
        # Response.json() detects the body encoding itself. The old
        # text.encode(req.encoding) round-trip crashed when encoding was None.
        releases = req.json().get('releases', {})
        for release in releases:
            try:
                ver = parse(release)
            except ValueError:
                # Newer packaging raises InvalidVersion (a ValueError) for
                # legacy non-PEP 440 strings. Skip them rather than abort the lookup.
                continue
            if not ver.is_prerelease:
                version = max(version, ver)
    return version
def test_what_latest_packages():
    """Report, per requirements file, packages whose installed version differs from PyPI.

    The files checked are ``requirements.txt`` followed by each
    ``reqs_optional/req*.txt`` match. For every requirement, the installed
    version (``importlib.metadata``) is compared with ``get_version``. A line is
    printed whenever the two differ. This test is informational: any exception
    is printed and the loop continues, so it never fails.

    Needs ``pip install requirements-parser`` (imported as ``requirements``) and
    network access to PyPI.
    """
    # pip install requirements-parser
    import requirements
    import glob
    # Hoisted: previously re-imported for every requirement inside the inner loop.
    from importlib.metadata import version as installed_version

    # get_version returns parse('0') when PyPI gave no usable stable release.
    no_release = parse('0')
    for req_name in ['requirements.txt'] + glob.glob('reqs_optional/req*.txt'):
        print("\n File: %s" % req_name, flush=True)
        with open(req_name, 'rt') as fd:
            for req in requirements.parse(fd):
                try:
                    current_version = installed_version(req.name)
                    latest_version = get_version(req.name)
                    if latest_version == no_release:
                        # Previously printed as "x -> 0", which looked like a real version.
                        print("%s: %s -> ? (no stable release info from PyPI)"
                              % (req.name, current_version), flush=True)
                    elif str(current_version) != str(latest_version):
                        print("%s: %s -> %s" % (req.name, current_version, latest_version), flush=True)
                except Exception as e:
                    # Best-effort report: one bad entry must not stop the listing.
                    print("Exception: %s" % str(e), flush=True)