File size: 2,444 Bytes
0d00bbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# HiF8 QAT Training Environment
# Base: CUDA 12.4 devel image (matches torch 2.6.0+cu124; nvcc is needed to compile the quant_cy extension)
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04

# ── System dependencies ──────────────────────────────────────────────────────
# DEBIAN_FRONTEND is declared as a build-time ARG instead of ENV: it still
# applies to every RUN step of this build, but it is no longer baked into the
# environment of containers started from the image.
ARG DEBIAN_FRONTEND=noninteractive
# --no-install-recommends keeps optional packages out of the image. Packages
# that apt would otherwise pull in only as recommendations and that the later
# steps may rely on (TLS root certificates for HTTPS downloads, the host C/C++
# toolchain and Python build helpers for the extension build) are therefore
# listed explicitly. The apt lists are removed in the same layer.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    ca-certificates \
    curl \
    git \
    libgomp1 \
    python3-dev \
    python3-pip \
    python3.10 \
    python3.10-dev \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Point `python3` and `python` at Python 3.10, then upgrade pip.
# --no-cache-dir keeps pip's download cache out of the image layer.
RUN ln -sf /usr/bin/python3.10 /usr/bin/python3 && \
    ln -sf /usr/bin/python3 /usr/bin/python && \
    pip install --no-cache-dir --upgrade pip

# ── PyTorch (CUDA 12.4 wheels) ───────────────────────────────────────────────
# Installed from the PyTorch cu124 wheel index, with versions pinned.
# --no-cache-dir prevents pip from keeping a second copy of the downloaded
# wheels in this layer.
RUN pip install --no-cache-dir \
    torch==2.6.0+cu124 \
    torchaudio==2.6.0+cu124 \
    torchvision==0.21.0+cu124 \
    --index-url https://download.pytorch.org/whl/cu124

# ── Python dependencies ──────────────────────────────────────────────────────
# Every package is pinned to an exact version for reproducible builds.
# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir \
    transformers==4.53.3 \
    accelerate==1.13.0 \
    datasets==4.8.4 \
    huggingface_hub==0.36.2 \
    tokenizers==0.21.4 \
    safetensors==0.7.0 \
    sentencepiece==0.2.1 \
    numpy==2.2.6 \
    pyarrow==23.0.1 \
    fsspec==2026.2.0 \
    requests==2.33.1 \
    packaging==26.0 \
    triton==3.2.0 \
    sympy==1.13.1 \
    antlr4-python3-runtime==4.11.0 \
    lm-eval==0.4.4

# ── Copy the project code ────────────────────────────────────────────────────
# The entire build context is copied to /workspace/tracy.
# NOTE(review): what ends up in the image depends on the build context;
# confirm that a .dockerignore excludes .git, checkpoints, datasets and secrets.
WORKDIR /workspace
COPY . /workspace/tracy

# ── Build the HiFloat8 quant_cy CUDA extension ───────────────────────────────
# Runs the project's build.sh from inside the hif8_cuda directory.
# NOTE(review): a GPU is normally not visible during `docker build`; if
# build.sh auto-detects the target GPU architecture it may need an explicit
# setting (e.g. TORCH_CUDA_ARCH_LIST) — confirm against build.sh.
WORKDIR /workspace/tracy/HiFloat8/hif8_cuda
RUN bash build.sh

# ── PYTHONPATH: make hif8.py / quant_cy directly importable ──────────────────
# PYTHONPATH is not defined earlier in this file, so appending ":${PYTHONPATH}"
# unconditionally would leave a trailing ':' — an empty path entry, which
# Python resolves to the current working directory. The ${VAR:+word} form
# appends ":<inherited value>" only when the base image already defines a
# non-empty PYTHONPATH, and appends nothing otherwise.
ENV PYTHONPATH="/workspace/tracy/HiFloat8/hif8_cuda:/workspace/tracy/pangu_hif8_pretrain${PYTHONPATH:+:${PYTHONPATH}}"

# Default working directory for containers started from this image.
WORKDIR /workspace/tracy