| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import platform |
|
|
| import accelerate |
| import datasets |
| import peft |
| import torch |
| import transformers |
| import trl |
| from transformers.utils import is_torch_cuda_available, is_torch_npu_available |
|
|
|
|
# Package version string; `print_env` reports it under the "`llamafactory` version" entry.
| VERSION = "0.9.4.dev0" |
|
|
|
|
def print_env() -> None:
    """Print a bulleted report describing the current runtime environment.

    The report always contains the `llamafactory`, platform, Python, PyTorch,
    Transformers, Datasets, Accelerate, PEFT and TRL versions. Accelerator
    details are added when CUDA or NPU support is reported as available.
    Optional entries (DeepSpeed, bitsandbytes, vLLM, git commit) are added only
    when they can be obtained; any error raised while probing them is ignored.
    The last entry states whether a ``data`` path exists relative to the
    current working directory.
    """
    # Insertion order is the order in which entries are printed.
    info = {}
    info["`llamafactory` version"] = VERSION
    info["Platform"] = platform.platform()
    info["Python version"] = platform.python_version()
    info["PyTorch version"] = torch.__version__
    info["Transformers version"] = transformers.__version__
    info["Datasets version"] = datasets.__version__
    info["Accelerate version"] = accelerate.__version__
    info["PEFT version"] = peft.__version__
    info["TRL version"] = trl.__version__

    if is_torch_cuda_available():
        info["PyTorch version"] = info["PyTorch version"] + " (GPU)"
        info.update(
            {
                "GPU type": torch.cuda.get_device_name(),
                "GPU number": torch.cuda.device_count(),
                # Second element of mem_get_info() divided by 1024**3 (bytes -> GiB);
                # presumably the device's total memory -- confirm against the PyTorch docs.
                "GPU memory": f"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB",
            }
        )

    if is_torch_npu_available():
        info["PyTorch version"] = info["PyTorch version"] + " (NPU)"
        info.update(
            {
                "NPU type": torch.npu.get_device_name(),
                "CANN version": torch.version.cann,
            }
        )

    # Optional dependencies: a failing import (or missing attribute) simply
    # leaves the corresponding entry out of the report.
    try:
        import deepspeed as _deepspeed

        info["DeepSpeed version"] = _deepspeed.__version__
    except Exception:
        pass

    try:
        import bitsandbytes as _bitsandbytes

        info["Bitsandbytes version"] = _bitsandbytes.__version__
    except Exception:
        pass

    try:
        import vllm as _vllm

        info["vLLM version"] = _vllm.__version__
    except Exception:
        pass

    # Best-effort git lookup in the current working directory; skipped when
    # git is unavailable or the command exits with a non-zero status.
    try:
        import subprocess

        git_head = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            check=True,
        )
        info["Git commit"] = git_head.stdout.strip()
    except Exception:
        pass

    info["Default data directory"] = "detected" if os.path.exists("data") else "not detected"

    report = "\n".join(f"- {label}: {detail}" for label, detail in info.items())
    print(f"\n{report}\n")
|
|