#!/usr/bin/env bash
# Dressify - Train ViT Outfit Encoder
# Trains the ViT outfit compatibility encoder on the Polyvore dataset.
#
# Environment overrides:
#   POLYVORE_ROOT  dataset root        (default: data/Polyvore)
#   EPOCHS         number of epochs    (default: 30)
#   BATCH_SIZE     training batch size (default: 32)
#   LR             learning rate       (default: 0.0005)

set -euo pipefail  # fail fast: error exit, unset-variable errors, pipeline failures

# Configuration
# shellcheck disable=SC2034  # CONFIG_FILE kept for parity with sibling train scripts
CONFIG_FILE="configs/outfit.yaml"
DATA_ROOT="${POLYVORE_ROOT:-data/Polyvore}"
EXPORT_DIR="models/exports"
EPOCHS="${EPOCHS:-30}"
BATCH_SIZE="${BATCH_SIZE:-32}"
LR="${LR:-0.0005}"

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color

echo -e "${BLUE}🚀 Starting ViT Outfit Encoder Training${NC}"
echo "=================================================="

# Check if dataset exists
if [ ! -d "$DATA_ROOT" ]; then
  echo -e "${RED}❌ Dataset not found at $DATA_ROOT${NC}"
  echo "Please run dataset preparation first:"
  echo "  python scripts/prepare_polyvore.py --root $DATA_ROOT --random_split"
  exit 1
fi

# Check if ResNet checkpoint exists (ViT training consumes its item embeddings)
RESNET_CHECKPOINT="$EXPORT_DIR/resnet_item_embedder_best.pth"
if [ ! -f "$RESNET_CHECKPOINT" ]; then
  echo -e "${RED}❌ ResNet checkpoint not found at $RESNET_CHECKPOINT${NC}"
  echo "Please train ResNet first:"
  echo "  ./scripts/train_item.sh"
  exit 1
fi

echo -e "${GREEN}✅ Found ResNet checkpoint: $RESNET_CHECKPOINT${NC}"

# Create outfit triplets on demand when the split file is missing
if [ ! -f "$DATA_ROOT/splits/outfit_triplets_train.json" ]; then
  echo -e "${YELLOW}⚠️ Outfit triplets not found${NC}"
  echo "Creating outfit triplets..."
  python scripts/prepare_polyvore.py --root "$DATA_ROOT" --random_split
fi

# Create export directory
mkdir -p "$EXPORT_DIR"

# Resume from an existing best checkpoint when one is present.
# An array keeps the flag as a single clean word and expands to nothing
# when empty (avoids the unquoted-expansion hack the old string needed).
RESUME_ARGS=()
if [ -f "$EXPORT_DIR/vit_outfit_model_best.pth" ]; then
  echo -e "${GREEN}✅ Found existing best checkpoint${NC}"
  echo "Starting from existing model..."
  RESUME_ARGS+=(--resume)
else
  echo -e "${BLUE}🆕 No existing checkpoint found, starting fresh${NC}"
fi

# Training command
echo -e "${BLUE}🎯 Training Configuration:${NC}"
echo "  Data Root: $DATA_ROOT"
echo "  ResNet Checkpoint: $RESNET_CHECKPOINT"
echo "  Epochs: $EPOCHS"
echo "  Batch Size: $BATCH_SIZE"
echo "  Learning Rate: $LR"
echo "  Export Dir: $EXPORT_DIR"
echo ""

# Run training. Test the command directly: under `set -e` a failing command
# would have aborted the script before the old `if [ $? -eq 0 ]` check could
# ever observe the failure, making the error branch unreachable.
echo -e "${BLUE}🔥 Starting ViT training...${NC}"
if python train_vit_triplet.py \
  --data_root "$DATA_ROOT" \
  --epochs "$EPOCHS" \
  --batch_size "$BATCH_SIZE" \
  --lr "$LR" \
  --export "$EXPORT_DIR/vit_outfit_model.pth" \
  ${RESUME_ARGS[@]+"${RESUME_ARGS[@]}"}; then
  echo -e "${GREEN}✅ Training completed successfully!${NC}"

  # List generated files; don't abort the success path if the glob
  # unexpectedly matches nothing under `set -e`
  echo -e "${BLUE}📁 Generated files:${NC}"
  ls -la "$EXPORT_DIR"/vit_* 2>/dev/null || true

  # Check if best checkpoint exists
  if [ -f "$EXPORT_DIR/vit_outfit_model_best.pth" ]; then
    echo -e "${GREEN}🏆 Best checkpoint saved: vit_outfit_model_best.pth${NC}"
  fi

  # Summarize metrics. The path is passed via argv and the heredoc delimiter
  # is quoted, so no shell text is ever interpolated into Python source.
  if [ -f "$EXPORT_DIR/vit_metrics.json" ]; then
    echo -e "${BLUE}📊 Training metrics saved: vit_metrics.json${NC}"
    echo "Metrics summary:"
    python - "$EXPORT_DIR/vit_metrics.json" <<'PY'
import json
import sys

with open(sys.argv[1]) as f:
    metrics = json.load(f)
best_loss = metrics.get('best_val_triplet_loss')
if best_loss is not None:
    print(f'Best validation triplet loss: {best_loss:.4f}')
else:
    print('Best validation loss: N/A')
print(f'Training history: {len(metrics.get("history", []))} epochs')
PY
  fi
else
  echo -e "${RED}❌ Training failed!${NC}"
  exit 1
fi

echo -e "${GREEN}🎉 ViT training script completed!${NC}"
echo ""
echo -e "${BLUE}Next steps:${NC}"
echo "1. Test inference: python app.py"
echo "2. Deploy to HF Space: ./scripts/deploy_space.sh"
echo "3. Push models to HF Hub: python utils/hf_utils.py --action push"