# Training image: PyTorch 2.1.0 / CUDA 11.8 runtime plus this project's
# system and Python dependencies. By default the container runs train.py
# with a fixed preset configuration via entrypoint.sh (generated below).
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime

WORKDIR /app

# System packages, installed and cleaned up in a single layer.
# --no-install-recommends keeps the image small; ca-certificates is listed
# explicitly because it is only a *recommended* dependency of git and is
# needed for HTTPS clones. DEBIAN_FRONTEND is set inline so a package prompt
# cannot stall the build, without leaking the variable into the runtime env.
# NOTE(review): the base image's Python is presumably the conda install under
# /opt/conda, in which case python3-dev (Ubuntu's system-Python headers) may
# be unused — verify before removing.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        git \
        libsndfile1 \
        python3-dev \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies before copying the source so this layer stays
# cached until requirements.txt changes. --no-cache-dir keeps pip's download
# cache out of the image.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Project source. NOTE(review): make sure a .dockerignore excludes .git,
# datasets and checkpoints so they are not sent as build context or baked in.
COPY . .

# Directories used at runtime: the Hugging Face cache (see HF_HUB_CACHE below)
# and runs/, presumably where train.py writes run outputs — TODO confirm.
RUN mkdir -p checkpoints/hf_cache runs

# PYTHONPATH: make project modules under /app importable.
# HF_HUB_CACHE: keep Hugging Face hub downloads under /app/checkpoints.
# PYTHONUNBUFFERED: flush Python output immediately so training logs show up
# in `docker logs` as they are written.
ENV PYTHONPATH=/app \
    HF_HUB_CACHE=/app/checkpoints/hf_cache \
    PYTHONUNBUFFERED=1

# Generate the entrypoint script that launches training with preset options.
# - printf '%s\n' writes each argument as one line, so the result does not
#   depend on whether /bin/sh's `echo` expands "\n" (dash does, bash does not).
# - `exec` replaces bash with python, so signals from `docker stop` go to the
#   training process rather than to an intermediate shell. Python as PID 1
#   ignores SIGTERM unless train.py installs a handler, so run with
#   `docker run --init` for a prompt shutdown.
# - "$@" appends any arguments passed to `docker run <image> ...`, so extra
#   options can be supplied without rebuilding (previously they were dropped).
RUN printf '%s\n' \
        '#!/bin/bash' \
        'exec python train.py \' \
        '    --config ./configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml \' \
        '    --dataset-dir dataset \' \
        '    --run-name training-run \' \
        '    --batch-size 2 \' \
        '    --max-steps 300 \' \
        '    --max-epochs 1000 \' \
        '    --save-every 100 \' \
        '    --num-workers 0 \' \
        '    "$@"' \
        > entrypoint.sh \
    && chmod +x entrypoint.sh

# NOTE(review): the container runs as root. Switching to a non-root USER would
# require write access to checkpoints/ and runs/ (and to any mounted volumes).
ENTRYPOINT ["./entrypoint.sh"]