# HiF8 QAT Training Environment # Base: CUDA 12.4 devel(与 torch 2.6.0+cu124 匹配,需要 nvcc 编译 quant_cy 扩展) FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 # ── 系统依赖 ────────────────────────────────────────────────────────────────── ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && apt-get install -y \ python3.10 python3.10-dev python3-pip \ git wget curl \ libgomp1 \ && rm -rf /var/lib/apt/lists/* RUN ln -sf /usr/bin/python3.10 /usr/bin/python3 && \ ln -sf /usr/bin/python3 /usr/bin/python && \ pip install --upgrade pip # ── PyTorch(CUDA 12.4)────────────────────────────────────────────────────── RUN pip install \ torch==2.6.0+cu124 \ torchaudio==2.6.0+cu124 \ torchvision==0.21.0+cu124 \ --index-url https://download.pytorch.org/whl/cu124 # ── Python 依赖 ─────────────────────────────────────────────────────────────── RUN pip install \ transformers==4.53.3 \ accelerate==1.13.0 \ datasets==4.8.4 \ huggingface_hub==0.36.2 \ tokenizers==0.21.4 \ safetensors==0.7.0 \ sentencepiece==0.2.1 \ numpy==2.2.6 \ pyarrow==23.0.1 \ fsspec==2026.2.0 \ requests==2.33.1 \ packaging==26.0 \ triton==3.2.0 \ sympy==1.13.1 \ antlr4-python3-runtime==4.11.0 \ lm-eval==0.4.4 # ── 复制项目代码 ────────────────────────────────────────────────────────────── WORKDIR /workspace COPY . /workspace/tracy # ── 编译 HiFloat8 quant_cy CUDA 扩展 ───────────────────────────────────────── WORKDIR /workspace/tracy/HiFloat8/hif8_cuda RUN bash build.sh # ── 设置 PYTHONPATH,使 hif8.py / quant_cy 可直接 import ───────────────────── ENV PYTHONPATH="/workspace/tracy/HiFloat8/hif8_cuda:/workspace/tracy/pangu_hif8_pretrain:${PYTHONPATH}" WORKDIR /workspace/tracy