File size: 3,797 Bytes
8bcf79a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/bin/bash

# Dressify - Train ViT Outfit Encoder
# This script trains the ViT outfit compatibility encoder on the Polyvore dataset.
#
# Optional environment overrides:
#   POLYVORE_ROOT  dataset root directory       (default: data/Polyvore)
#   EXPORT_DIR     checkpoint/metrics directory (default: models/exports)
#   EPOCHS         number of training epochs    (default: 30)
#   BATCH_SIZE     training batch size          (default: 32)
#   LR             learning rate                (default: 0.0005)

# Exit on any unhandled error, on use of an unset variable, and when any
# stage of a pipeline fails (not just the last one).
set -euo pipefail

# Configuration
# NOTE(review): CONFIG_FILE is never passed to train_vit_triplet.py by this
# script — confirm whether the trainer should receive it (e.g. via a flag).
CONFIG_FILE="configs/outfit.yaml"
DATA_ROOT="${POLYVORE_ROOT:-data/Polyvore}"
EXPORT_DIR="${EXPORT_DIR:-models/exports}"
EPOCHS="${EPOCHS:-30}"
BATCH_SIZE="${BATCH_SIZE:-32}"
LR="${LR:-0.0005}"

# ANSI colour escapes, consumed by `echo -e` below.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

echo -e "${BLUE}🚀 Starting ViT Outfit Encoder Training${NC}"
echo "=================================================="

# Preflight 1: the Polyvore dataset directory must exist.
if [[ ! -d "$DATA_ROOT" ]]; then
    echo -e "${RED}❌ Dataset not found at $DATA_ROOT${NC}" >&2
    echo "Please run dataset preparation first:" >&2
    echo "  python scripts/prepare_polyvore.py --root \"$DATA_ROOT\" --random_split" >&2
    exit 1
fi

# Preflight 2: the ResNet item-embedder checkpoint must already be trained.
RESNET_CHECKPOINT="$EXPORT_DIR/resnet_item_embedder_best.pth"
if [[ ! -f "$RESNET_CHECKPOINT" ]]; then
    echo -e "${RED}❌ ResNet checkpoint not found at $RESNET_CHECKPOINT${NC}" >&2
    echo "Please train ResNet first:" >&2
    echo "  ./scripts/train_item.sh" >&2
    exit 1
fi

echo -e "${GREEN}✅ Found ResNet checkpoint: $RESNET_CHECKPOINT${NC}"

# Preflight 3: generate the outfit triplet split if it is missing, then make
# sure the preparation step actually produced it before training starts.
TRIPLETS_FILE="$DATA_ROOT/splits/outfit_triplets_train.json"
if [[ ! -f "$TRIPLETS_FILE" ]]; then
    echo -e "${YELLOW}⚠️  Outfit triplets not found${NC}"
    echo "Creating outfit triplets..."
    if ! python scripts/prepare_polyvore.py --root "$DATA_ROOT" --random_split; then
        echo -e "${RED}❌ Dataset preparation failed${NC}" >&2
        exit 1
    fi
    if [[ ! -f "$TRIPLETS_FILE" ]]; then
        echo -e "${RED}❌ Outfit triplets still missing at $TRIPLETS_FILE after preparation${NC}" >&2
        exit 1
    fi
fi

# Create export directory
mkdir -p "$EXPORT_DIR"

# Decide whether to resume. START_FROM_CHECKPOINT holds the extra flag that is
# appended to the train_vit_triplet.py command line: "--resume" when a previous
# best checkpoint exists in EXPORT_DIR, otherwise the empty string.
if [[ -f "$EXPORT_DIR/vit_outfit_model_best.pth" ]]; then
    echo -e "${GREEN}✅ Found existing best checkpoint${NC}"
    echo "Starting from existing model..."
    START_FROM_CHECKPOINT="--resume"
else
    echo -e "${BLUE}🆕 No existing checkpoint found, starting fresh${NC}"
    START_FROM_CHECKPOINT=""
fi

# Training command
echo -e "${BLUE}🎯 Training Configuration:${NC}"
echo "  Data Root: $DATA_ROOT"
echo "  ResNet Checkpoint: $RESNET_CHECKPOINT"
echo "  Epochs: $EPOCHS"
echo "  Batch Size: $BATCH_SIZE"
echo "  Learning Rate: $LR"
echo "  Export Dir: $EXPORT_DIR"
echo ""

# Run training. The command is used as the `if` condition on purpose: under
# `set -e` a bare failing command aborts the script immediately, so a later
# `$?` check (and its failure message) would never be reached.
# `${VAR:+"$VAR"}` expands to nothing when the resume flag is empty, and to a
# single quoted word otherwise.
echo -e "${BLUE}🔥 Starting ViT training...${NC}"
if ! python train_vit_triplet.py \
    --data_root "$DATA_ROOT" \
    --epochs "$EPOCHS" \
    --batch_size "$BATCH_SIZE" \
    --lr "$LR" \
    --export "$EXPORT_DIR/vit_outfit_model.pth" \
    ${START_FROM_CHECKPOINT:+"$START_FROM_CHECKPOINT"}; then
    echo -e "${RED}❌ Training failed!${NC}" >&2
    exit 1
fi

echo -e "${GREEN}✅ Training completed successfully!${NC}"

# List generated files. If nothing matches vit_*, `ls` exits non-zero; that
# must not abort (via set -e) a run whose training already succeeded.
echo -e "${BLUE}📁 Generated files:${NC}"
ls -la "$EXPORT_DIR"/vit_* || echo "  (no vit_* files found in $EXPORT_DIR)"

# Check if best checkpoint exists
if [[ -f "$EXPORT_DIR/vit_outfit_model_best.pth" ]]; then
    echo -e "${GREEN}🏆 Best checkpoint saved: vit_outfit_model_best.pth${NC}"
fi

# Summarise metrics. The path is passed as argv instead of being spliced into
# the Python source, so quotes/backslashes in EXPORT_DIR cannot break the
# snippet. A malformed or unexpected metrics file only produces a warning.
METRICS_FILE="$EXPORT_DIR/vit_metrics.json"
if [[ -f "$METRICS_FILE" ]]; then
    echo -e "${BLUE}📊 Training metrics saved: vit_metrics.json${NC}"
    echo "Metrics summary:"
    python - "$METRICS_FILE" <<'PY' || echo -e "${YELLOW}⚠️  Could not summarise $METRICS_FILE${NC}" >&2
import json
import sys

with open(sys.argv[1]) as f:
    metrics = json.load(f)
best_loss = metrics.get('best_val_triplet_loss')
if best_loss is not None:
    print(f'Best validation triplet loss: {best_loss:.4f}')
else:
    print('Best validation loss: N/A')
print(f'Training history: {len(metrics.get("history", []))} epochs')
PY
fi

# Final banner plus follow-up commands for the user.
echo -e "${GREEN}🎉 ViT training script completed!${NC}"
echo ""
echo -e "${BLUE}Next steps:${NC}"
echo "1. Test inference: python app.py"
echo "2. Deploy to HF Space: ./scripts/deploy_space.sh"
echo "3. Push models to HF Hub: python utils/hf_utils.py --action push"