File size: 748 Bytes
6d8a316
 
ce5bcf8
 
6d8a316
 
 
 
 
 
 
 
 
 
 
 
 
ce66137
6d8a316
 
 
a52065b
 
 
ce5bcf8
 
 
6d8a316
b491772
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-devel

# 書き込み可能なワークディレクトリ
WORKDIR /tmp/training

# 基本パッケージ
RUN apt-get update && apt-get install -y \
    git \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Python依存関係
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# 学習スクリプト
COPY train.py .
COPY train_multi_gpu.py .

# HFトークンは環境変数で渡す
ENV HF_TOKEN=""
ENV HF_HOME=/tmp/hf_cache
ENV TRANSFORMERS_CACHE=/tmp/hf_cache

# ディレクトリ作成
RUN mkdir -p /tmp/hf_cache /tmp/training/checkpoints /tmp/training/output && \
    chmod -R 777 /tmp/hf_cache /tmp/training

# シングルGPU学習 (L40S 48GB)
CMD ["python", "train.py"]