#!/bin/bash
# Dressify - Train ViT Outfit Encoder
#
# Trains the ViT outfit compatibility encoder on the Polyvore dataset.
# Requires a previously trained ResNet item-embedder checkpoint
# (see scripts/train_item.sh).
#
# Environment overrides:
#   POLYVORE_ROOT - dataset root          (default: data/Polyvore)
#   EPOCHS        - number of epochs      (default: 30)
#   BATCH_SIZE    - training batch size   (default: 32)
#   LR            - learning rate         (default: 0.0005)

set -euo pipefail # exit on error, on unset variables, and on pipeline failures

# Configuration (env-overridable where noted above)
CONFIG_FILE="configs/outfit.yaml"
DATA_ROOT="${POLYVORE_ROOT:-data/Polyvore}"
EXPORT_DIR="models/exports"
EPOCHS="${EPOCHS:-30}"
BATCH_SIZE="${BATCH_SIZE:-32}"
LR="${LR:-0.0005}"

# ANSI colors for console output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
echo -e "${BLUE}🚀 Starting ViT Outfit Encoder Training${NC}"
echo "=================================================="

# Preflight: the Polyvore dataset must already be prepared on disk.
if [ ! -d "$DATA_ROOT" ]; then
  echo -e "${RED}❌ Dataset not found at $DATA_ROOT${NC}"
  echo "Please run dataset preparation first:"
  echo " python scripts/prepare_polyvore.py --root $DATA_ROOT --random_split"
  exit 1
fi

# Preflight: the ViT encoder builds on ResNet item embeddings, so the
# best ResNet checkpoint must exist before we can train.
RESNET_CHECKPOINT="$EXPORT_DIR/resnet_item_embedder_best.pth"
if [ ! -f "$RESNET_CHECKPOINT" ]; then
  echo -e "${RED}❌ ResNet checkpoint not found at $RESNET_CHECKPOINT${NC}"
  echo "Please train ResNet first:"
  echo " ./scripts/train_item.sh"
  exit 1
fi
echo -e "${GREEN}✅ Found ResNet checkpoint: $RESNET_CHECKPOINT${NC}"

# Outfit triplets are generated on demand when the train split is missing.
if [ ! -f "$DATA_ROOT/splits/outfit_triplets_train.json" ]; then
  echo -e "${YELLOW}⚠️ Outfit triplets not found${NC}"
  echo "Creating outfit triplets..."
  python scripts/prepare_polyvore.py --root "$DATA_ROOT" --random_split
fi
# Ensure the export directory exists before training writes into it.
mkdir -p "$EXPORT_DIR"

# Resume from a previous best checkpoint when one is present.
# START_FROM_CHECKPOINT holds either "--resume" or the empty string; it is
# expanded (deliberately unquoted) into the training command line later.
if [ -f "$EXPORT_DIR/vit_outfit_model_best.pth" ]; then
  echo -e "${GREEN}✅ Found existing best checkpoint${NC}"
  echo "Starting from existing model..."
  START_FROM_CHECKPOINT="--resume"
else
  echo -e "${BLUE}🆕 No existing checkpoint found, starting fresh${NC}"
  START_FROM_CHECKPOINT=""
fi
# Summarize the effective training configuration for the user.
echo -e "${BLUE}🎯 Training Configuration:${NC}"
echo " Data Root: $DATA_ROOT"
echo " ResNet Checkpoint: $RESNET_CHECKPOINT"
echo " Epochs: $EPOCHS"
echo " Batch Size: $BATCH_SIZE"
echo " Learning Rate: $LR"
echo " Export Dir: $EXPORT_DIR"
echo ""
# Run training
echo -e "${BLUE}🔥 Starting ViT training...${NC}"

# Build the argument list as an array; the resume flag is appended only when
# non-empty so an empty value never becomes a stray empty argument.
train_args=(
  --data_root "$DATA_ROOT"
  --epochs "$EPOCHS"
  --batch_size "$BATCH_SIZE"
  --lr "$LR"
  --export "$EXPORT_DIR/vit_outfit_model.pth"
)
if [ -n "$START_FROM_CHECKPOINT" ]; then
  train_args+=("$START_FROM_CHECKPOINT")
fi

# Run the trainer inside `if` so a non-zero exit is handled here. Testing
# `$?` on the next line would be dead code: with `set -e` a failing python
# aborts the whole script before the check, so the failure branch could
# never run.
if python train_vit_triplet.py "${train_args[@]}"; then
  echo -e "${GREEN}✅ Training completed successfully!${NC}"

  # List generated files
  echo -e "${BLUE}📁 Generated files:${NC}"
  ls -la "$EXPORT_DIR"/vit_*

  # Report the best checkpoint when the trainer saved one.
  if [ -f "$EXPORT_DIR/vit_outfit_model_best.pth" ]; then
    echo -e "${GREEN}🏆 Best checkpoint saved: vit_outfit_model_best.pth${NC}"
  fi

  # Print a short metrics summary when the trainer wrote metrics.
  if [ -f "$EXPORT_DIR/vit_metrics.json" ]; then
    echo -e "${BLUE}📊 Training metrics saved: vit_metrics.json${NC}"
    echo "Metrics summary:"
    python -c "
import json
with open('$EXPORT_DIR/vit_metrics.json') as f:
    metrics = json.load(f)
best_loss = metrics.get('best_val_triplet_loss')
if best_loss is not None:
    print(f'Best validation triplet loss: {best_loss:.4f}')
else:
    print('Best validation loss: N/A')
print(f'Training history: {len(metrics.get(\"history\", []))} epochs')
"
  fi
else
  echo -e "${RED}❌ Training failed!${NC}"
  exit 1
fi
# Wrap-up: point the user at the follow-on steps.
echo -e "${GREEN}🎉 ViT training script completed!${NC}"
echo ""
echo -e "${BLUE}Next steps:${NC}"
echo "1. Test inference: python app.py"
echo "2. Deploy to HF Space: ./scripts/deploy_space.sh"
echo "3. Push models to HF Hub: python utils/hf_utils.py --action push"