#!/usr/bin/env bash
# Dressify - Train ViT Outfit Encoder
#
# Trains the ViT outfit compatibility encoder on the Polyvore dataset.
#
# Environment overrides (defaults in parentheses):
#   POLYVORE_ROOT  dataset root    (data/Polyvore)
#   EPOCHS         training epochs (30)
#   BATCH_SIZE     batch size      (32)
#   LR             learning rate   (0.0005)
set -e # Exit on any error

# Configuration
# NOTE(review): CONFIG_FILE is never passed to train_vit_triplet.py in this
# script; confirm whether the trainer is expected to receive it.
CONFIG_FILE="configs/outfit.yaml"
DATA_ROOT="${POLYVORE_ROOT:-data/Polyvore}"
EXPORT_DIR="models/exports"
EPOCHS="${EPOCHS:-30}"
BATCH_SIZE="${BATCH_SIZE:-32}"
LR="${LR:-0.0005}"

# ANSI color codes for terminal output (interpreted by `echo -e`)
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
echo -e "${BLUE}🚀 Starting ViT Outfit Encoder Training${NC}"
echo "=================================================="

# Abort early if the Polyvore dataset directory is missing; diagnostics go
# to stderr so they are not mixed into captured stdout.
if [ ! -d "$DATA_ROOT" ]; then
    {
        echo -e "${RED}❌ Dataset not found at $DATA_ROOT${NC}"
        echo "Please run dataset preparation first:"
        echo " python scripts/prepare_polyvore.py --root \"$DATA_ROOT\" --random_split"
    } >&2
    exit 1
fi
# The ResNet item-embedder checkpoint must exist before the outfit encoder
# is trained; otherwise point the user at the item-training script and stop.
RESNET_CHECKPOINT="$EXPORT_DIR/resnet_item_embedder_best.pth"
if [ ! -f "$RESNET_CHECKPOINT" ]; then
    {
        echo -e "${RED}❌ ResNet checkpoint not found at $RESNET_CHECKPOINT${NC}"
        echo "Please train ResNet first:"
        echo " ./scripts/train_item.sh"
    } >&2
    exit 1
fi
echo -e "${GREEN}✅ Found ResNet checkpoint: $RESNET_CHECKPOINT${NC}"
# Generate outfit triplets if the training split file is missing.
TRIPLETS_FILE="$DATA_ROOT/splits/outfit_triplets_train.json"
if [ ! -f "$TRIPLETS_FILE" ]; then
    echo -e "${YELLOW}⚠️ Outfit triplets not found${NC}"
    echo "Creating outfit triplets..."
    python scripts/prepare_polyvore.py --root "$DATA_ROOT" --random_split
    # Verify the preparation step actually produced the file, so a missing
    # split fails here with a clear message instead of inside the trainer.
    if [ ! -f "$TRIPLETS_FILE" ]; then
        echo -e "${RED}❌ Outfit triplets still missing at $TRIPLETS_FILE${NC}" >&2
        exit 1
    fi
fi

# Create export directory
mkdir -p "$EXPORT_DIR"

# Pass --resume when a best checkpoint from an earlier run is present.
# START_FROM_CHECKPOINT is a plain string; it is left empty for a fresh run.
# NOTE(review): --resume carries no path, so the trainer presumably locates
# the checkpoint itself — confirm against train_vit_triplet.py.
if [ -f "$EXPORT_DIR/vit_outfit_model_best.pth" ]; then
    echo -e "${GREEN}✅ Found existing best checkpoint${NC}"
    echo "Starting from existing model..."
    START_FROM_CHECKPOINT="--resume"
else
    echo -e "${BLUE}🆕 No existing checkpoint found, starting fresh${NC}"
    START_FROM_CHECKPOINT=""
fi
# Print the effective training configuration before launching the run.
echo -e "${BLUE}🎯 Training Configuration:${NC}"
echo " Data Root: $DATA_ROOT"
echo " ResNet Checkpoint: $RESNET_CHECKPOINT"
echo " Epochs: $EPOCHS"
echo " Batch Size: $BATCH_SIZE"
echo " Learning Rate: $LR"
echo " Export Dir: $EXPORT_DIR"
echo ""
# Run training.
# The training command is used directly as the `if` condition. Under
# `set -e`, a failing command on its own line would abort the script before
# a following `$?` check ran, leaving the failure branch unreachable.
# $START_FROM_CHECKPOINT is intentionally unquoted so that an empty value
# contributes no argument at all.
# NOTE(review): $RESNET_CHECKPOINT is verified earlier but not passed to
# train_vit_triplet.py; confirm the trainer finds it on its own.
echo -e "${BLUE}🔥 Starting ViT training...${NC}"
if python train_vit_triplet.py \
    --data_root "$DATA_ROOT" \
    --epochs "$EPOCHS" \
    --batch_size "$BATCH_SIZE" \
    --lr "$LR" \
    --export "$EXPORT_DIR/vit_outfit_model.pth" \
    $START_FROM_CHECKPOINT; then
    echo -e "${GREEN}✅ Training completed successfully!${NC}"

    # List generated files. An unmatched glob makes `ls` fail, which under
    # `set -e` would abort an otherwise successful run, so tolerate it.
    echo -e "${BLUE}📁 Generated files:${NC}"
    ls -la "$EXPORT_DIR"/vit_* 2>/dev/null || echo " (no vit_* files found in $EXPORT_DIR)"

    # Report the best checkpoint, if one was written
    if [ -f "$EXPORT_DIR/vit_outfit_model_best.pth" ]; then
        echo -e "${GREEN}🏆 Best checkpoint saved: vit_outfit_model_best.pth${NC}"
    fi

    # Summarize metrics. The path is passed as argv to a quoted heredoc
    # rather than spliced into Python source, so unusual characters in
    # EXPORT_DIR cannot break the snippet. The summary is informational,
    # so a parse failure is reported instead of failing the whole run.
    METRICS_FILE="$EXPORT_DIR/vit_metrics.json"
    if [ -f "$METRICS_FILE" ]; then
        echo -e "${BLUE}📊 Training metrics saved: vit_metrics.json${NC}"
        echo "Metrics summary:"
        python - "$METRICS_FILE" <<'PY' || echo -e "${YELLOW}⚠️ Could not summarize metrics from $METRICS_FILE${NC}" >&2
import json
import sys

with open(sys.argv[1]) as f:
    metrics = json.load(f)
best_loss = metrics.get('best_val_triplet_loss')
if best_loss is not None:
    print(f'Best validation triplet loss: {best_loss:.4f}')
else:
    print('Best validation loss: N/A')
print(f'Training history: {len(metrics.get("history", []))} epochs')
PY
    fi
else
    echo -e "${RED}❌ Training failed!${NC}" >&2
    exit 1
fi
# Final banner with suggested follow-up commands.
echo -e "${GREEN}🎉 ViT training script completed!${NC}"
echo ""
echo -e "${BLUE}Next steps:${NC}"
echo "1. Test inference: python app.py"
echo "2. Deploy to HF Space: ./scripts/deploy_space.sh"
echo "3. Push models to HF Hub: python utils/hf_utils.py --action push"