JensLundsgaard committed on
Commit
ac7b719
·
verified ·
1 Parent(s): 9bd3a1c

Upload train_model.sh with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_model.sh +160 -0
train_model.sh ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/bin/bash
#
# train_model.sh - prepare the environment and launch train.py for an
# ablation run.
#
# Usage:
#   ./train_model.sh [MODE] [LOSS_TYPE] [options] [extra train.py args...]
#
# Positional arguments (optional; must precede any option):
#   MODE        first argument handed to train.py
#               (default: convlstm_latent_split)
#   LOSS_TYPE   value handed to train.py as --loss-type (default: l1)
#
# Options (each overrides the environment variable of the same purpose):
#   --ms-ssim-weight VALUE    overrides MS_SSIM_WEIGHT   (default 0.5)
#   --rec-weight VALUE        overrides REC_WEIGHT       (default 0.5)
#   --temporal-weight VALUE   overrides TEMPORAL_WEIGHT  (default 0.1)
#   --dropout-rate VALUE      overrides DROPOUT_RATE     (default 0.1)
#   --no-convlstm             sets USE_CONVLSTM=false
#   --no-residual             sets USE_RESIDUAL=false
#   --no-batchnorm            sets USE_BATCHNORM=false
# Any other argument is forwarded to train.py unchanged.
#
# Files read from the working directory:
#   api_keys.txt            first line -> HF_TOKEN, last line -> WANDB_KEY
#   embryo_dataset.tar.gz   extracted before training; the extracted
#                           embryo_dataset directory is removed afterwards
#
# Exit status: that of train.py when it fails; otherwise that of the final
# dataset cleanup.

# Print an error message on stderr and abort.
die() {
  printf 'train_model.sh: %s\n' "$*" >&2
  exit 1
}

# Abort unless the option named in $1 is followed by a value.
# $2 is the caller's current argument count ($#).
need_value() {
  [[ "$2" -ge 2 ]] || die "option $1 requires a value"
}

# Print "(DISABLED)" when the given weight/rate is exactly the string "0".
disabled_if_zero() {
  if [[ "$1" == "0" ]]; then
    echo "(DISABLED)"
  fi
}

# Print ENABLED when the given switch is "true", DISABLED otherwise.
enabled_or_disabled() {
  if [[ "$1" == "true" ]]; then
    echo "ENABLED"
  else
    echo "DISABLED"
  fi
}

# Print the flag in $2 when the switch in $1 is "false", nothing otherwise.
flag_if_false() {
  if [[ "$1" == "false" ]]; then
    echo "$2"
  fi
}

# Parse positional arguments. A positional is consumed only when it is
# present and does not look like an option. A blind "shift 2" leaves the
# argument list untouched when fewer than two arguments were given, which
# would make MODE be parsed a second time below as a pass-through argument.
MODE="convlstm_latent_split"  # Default to convlstm_latent_split
LOSS_TYPE="l1"                # Default to l1
if [[ $# -gt 0 && "$1" != -* ]]; then
  MODE="${1:-convlstm_latent_split}"
  shift
  if [[ $# -gt 0 && "$1" != -* ]]; then
    LOSS_TYPE="${1:-l1}"
    shift
  fi
fi

# Default values (can be overridden by environment variables or command-line args)
MS_SSIM_WEIGHT="${MS_SSIM_WEIGHT:-0.5}"
REC_WEIGHT="${REC_WEIGHT:-0.5}"
TEMPORAL_WEIGHT="${TEMPORAL_WEIGHT:-0.1}"
DROPOUT_RATE="${DROPOUT_RATE:-0.1}"
USE_CONVLSTM="${USE_CONVLSTM:-true}"
USE_RESIDUAL="${USE_RESIDUAL:-true}"
USE_BATCHNORM="${USE_BATCHNORM:-true}"

# Parse command-line arguments (override defaults and environment variables).
# Pass-through arguments are kept in an array so each one reaches train.py
# as exactly one argument, whatever characters it contains.
EXTRA_ARGS=()
while [[ $# -gt 0 ]]; do
  case "$1" in
    --ms-ssim-weight)
      # Without this check a trailing option makes "shift 2" fail without
      # shifting, and the loop would never terminate.
      need_value "$1" "$#"
      MS_SSIM_WEIGHT="$2"
      shift 2
      ;;
    --rec-weight)
      need_value "$1" "$#"
      REC_WEIGHT="$2"
      shift 2
      ;;
    --temporal-weight)
      need_value "$1" "$#"
      TEMPORAL_WEIGHT="$2"
      shift 2
      ;;
    --dropout-rate)
      need_value "$1" "$#"
      DROPOUT_RATE="$2"
      shift 2
      ;;
    --no-convlstm)
      USE_CONVLSTM="false"
      shift
      ;;
    --no-residual)
      USE_RESIDUAL="false"
      shift
      ;;
    --no-batchnorm)
      USE_BATCHNORM="false"
      shift
      ;;
    *)
      # Pass through any unrecognized arguments
      EXTRA_ARGS+=("$1")
      shift
      ;;
  esac
done

# Best effort: the packages may already be installed, so a failure here is
# reported but does not stop the run.
pip install huggingface_hub wandb safetensors \
  || printf 'train_model.sh: warning: pip install failed; continuing\n' >&2

# api_keys.txt: first line is the Hugging Face token, last line is the
# Weights & Biases key (a one-line file yields the same value for both).
# When the file is absent, whatever the caller already exported is kept
# instead of being overwritten with an empty string.
if [[ -r api_keys.txt ]]; then
  HF_TOKEN="$(head -n 1 api_keys.txt)"
  WANDB_KEY="$(tail -n 1 api_keys.txt)"
  export HF_TOKEN WANDB_KEY
else
  printf 'train_model.sh: warning: api_keys.txt not readable; HF_TOKEN and WANDB_KEY left as inherited\n' >&2
fi

export TORCH_DISTRIBUTED_DEBUG=DETAIL
export NCCL_DEBUG=INFO

tar -zxf embryo_dataset.tar.gz || die "failed to extract embryo_dataset.tar.gz"

# Create training configuration file for reproducibility
cat > training_config.txt << EOF
ABLATION STUDY - Training Configuration
========================================
Date: $(date)
Script: train_model.sh
Mode: $MODE

Loss Configuration:
- Loss Type: $LOSS_TYPE
- MS-SSIM Weight: $MS_SSIM_WEIGHT $(disabled_if_zero "$MS_SSIM_WEIGHT")
- Reconstruction Weight: $REC_WEIGHT $(disabled_if_zero "$REC_WEIGHT")
- Temporal Weight: $TEMPORAL_WEIGHT $(disabled_if_zero "$TEMPORAL_WEIGHT")

Model Architecture:
- ConvLSTM: $(enabled_or_disabled "$USE_CONVLSTM")
- Residual Connections: $(enabled_or_disabled "$USE_RESIDUAL")
- Batch Normalization: $(enabled_or_disabled "$USE_BATCHNORM")
- Dropout Rate: $DROPOUT_RATE $(disabled_if_zero "$DROPOUT_RATE")

Latent Split: $([[ "$MODE" == "convlstm_latent_split" ]] && echo "ENABLED (2048 empty + 2048 embryo)" || echo "DISABLED")

Command: python train.py $MODE \\
--loss-type $LOSS_TYPE \\
--ms-ssim-weight $MS_SSIM_WEIGHT \\
--rec-weight $REC_WEIGHT \\
--temporal-weight $TEMPORAL_WEIGHT \\
--dropout-rate $DROPOUT_RATE \\
$(flag_if_false "$USE_CONVLSTM" "--no-convlstm") \\
$(flag_if_false "$USE_RESIDUAL" "--no-residual") \\
$(flag_if_false "$USE_BATCHNORM" "--no-batchnorm")
EOF

# Record pass-through arguments too, so the saved configuration matches the
# command that is actually executed. Nothing is appended when there are none.
if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then
  printf 'Extra Arguments: %s\n' "${EXTRA_ARGS[*]}" >> training_config.txt
fi

echo "========================================="
echo "ABLATION STUDY - Training Configuration"
echo "========================================="
echo ""
echo "Mode: $MODE"
echo ""
echo "Loss Configuration:"
echo " - Loss Type: $LOSS_TYPE"
echo " - MS-SSIM Weight: $MS_SSIM_WEIGHT $(disabled_if_zero "$MS_SSIM_WEIGHT")"
echo " - Reconstruction Weight: $REC_WEIGHT $(disabled_if_zero "$REC_WEIGHT")"
echo " - Temporal Weight: $TEMPORAL_WEIGHT $(disabled_if_zero "$TEMPORAL_WEIGHT")"
echo ""
echo "Model Architecture:"
echo " - ConvLSTM: $(enabled_or_disabled "$USE_CONVLSTM")"
echo " - Residual Connections: $(enabled_or_disabled "$USE_RESIDUAL")"
echo " - Batch Normalization: $(enabled_or_disabled "$USE_BATCHNORM")"
echo " - Dropout Rate: $DROPOUT_RATE $(disabled_if_zero "$DROPOUT_RATE")"
echo ""
if [[ "$MODE" == "convlstm_latent_split" ]]; then
  echo "Latent Split: ENABLED"
  echo " - Empty Well Latent: 2048 (first half)"
  echo " - Embryo Latent: 2048 (second half)"
else
  echo "Latent Split: DISABLED"
fi
echo "========================================="
cat training_config.txt
echo "========================================="

# Build command with ablation arguments. An array (rather than a string run
# through eval) keeps every argument intact and prevents the shell from
# re-interpreting metacharacters in user-supplied values.
CMD=(python train.py "$MODE" --loss-type "$LOSS_TYPE")
CMD+=(--ms-ssim-weight "$MS_SSIM_WEIGHT")
CMD+=(--rec-weight "$REC_WEIGHT")
CMD+=(--temporal-weight "$TEMPORAL_WEIGHT")
CMD+=(--dropout-rate "$DROPOUT_RATE")

if [[ "$USE_CONVLSTM" == "false" ]]; then
  CMD+=(--no-convlstm)
fi

if [[ "$USE_RESIDUAL" == "false" ]]; then
  CMD+=(--no-residual)
fi

if [[ "$USE_BATCHNORM" == "false" ]]; then
  CMD+=(--no-batchnorm)
fi

# Add any extra arguments that were passed through
if [[ ${#EXTRA_ARGS[@]} -gt 0 ]]; then
  CMD+=("${EXTRA_ARGS[@]}")
fi

# Run training with specified configuration
echo "Executing: ${CMD[*]}"
"${CMD[@]}"
train_status=$?

#python -m torch.distributed.launch --nproc_per_node=4 --use_env train.py

# Always remove the extracted dataset, then report a training failure in
# preference to the cleanup status so callers can detect a failed run.
cleanup_status=0
rm -r -- embryo_dataset || cleanup_status=$?

if [[ "$train_status" -ne 0 ]]; then
  exit "$train_status"
fi
exit "$cleanup_status"