Spaces:
Paused
Paused
# Dressify - Train ResNet Item Embedder
# Trains the ResNet50 item embedder on the Polyvore dataset.
set -e  # stop at the first command that fails

# ---------------------------------------------------------------------------
# Settings. POLYVORE_ROOT, EPOCHS, BATCH_SIZE and LR may be overridden from the
# environment; the defaults below apply when they are unset or empty.
# ---------------------------------------------------------------------------
# NOTE(review): CONFIG_FILE is not passed to train_resnet.py in this script —
# confirm whether it is still needed.
CONFIG_FILE="configs/item.yaml"
EXPORT_DIR="models/exports"
DATA_ROOT="${POLYVORE_ROOT:-data/Polyvore}"
: "${EPOCHS:=20}"
: "${BATCH_SIZE:=64}"
: "${LR:=0.001}"

# ANSI colour escapes, stored as literal backslash sequences so that
# `echo -e` expands them at print time.
NC='\033[0m'  # reset to default colour
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
echo -e "${BLUE}π Starting ResNet Item Embedder Training${NC}"
echo "=================================================="

# Run the Polyvore preparation script against DATA_ROOT with a random split.
_run_polyvore_prep() {
  python scripts/prepare_polyvore.py --root "$DATA_ROOT" --random_split
}

# Step 1: the dataset directory itself is missing -> prepare it.
if ! [ -d "$DATA_ROOT" ]; then
  echo -e "${YELLOW}β οΈ Dataset not found at $DATA_ROOT${NC}"
  echo "Running dataset preparation..."
  _run_polyvore_prep
fi

# Step 2: the training split file is still missing -> run preparation again.
if ! [ -f "$DATA_ROOT/splits/train.json" ]; then
  echo -e "${YELLOW}β οΈ Training splits not found${NC}"
  echo "Creating splits..."
  _run_polyvore_prep
fi
# Ensure the export directory exists before anything is written into it.
mkdir -p "$EXPORT_DIR"

# Request --resume from the trainer only when a previous best checkpoint exists.
# NOTE(review): the flag carries no path; presumably train_resnet.py locates the
# checkpoint on its own — confirm.
START_FROM_CHECKPOINT=""
if [ -f "$EXPORT_DIR/resnet_item_embedder_best.pth" ]; then
  echo -e "${GREEN}β Found existing best checkpoint${NC}"
  echo "Starting from existing model..."
  START_FROM_CHECKPOINT="--resume"
else
  echo -e "${BLUE}π No existing checkpoint found, starting fresh${NC}"
fi
# Show the effective training configuration, one " Label: value" line each.
printf '%b\n' "${BLUE}π― Training Configuration:${NC}"
printf ' %s: %s\n' \
  "Data Root" "$DATA_ROOT" \
  "Epochs" "$EPOCHS" \
  "Batch Size" "$BATCH_SIZE" \
  "Learning Rate" "$LR" \
  "Export Dir" "$EXPORT_DIR"
printf '\n'
# Run training.
# The trainer is the `if` condition itself. With `set -e` active, the old
# "run it, then test $?" form exited on failure before the check ran, which
# made the "Training failed!" branch unreachable.
echo -e "${BLUE}π₯ Starting training...${NC}"
if python train_resnet.py \
    --data_root "$DATA_ROOT" \
    --epochs "$EPOCHS" \
    --batch_size "$BATCH_SIZE" \
    --lr "$LR" \
    --out "$EXPORT_DIR/resnet_item_embedder.pth" \
    ${START_FROM_CHECKPOINT:+"$START_FROM_CHECKPOINT"}; then
  # ${VAR:+"$VAR"} adds the resume flag only when it is non-empty, without
  # relying on unquoted word splitting.
  echo -e "${GREEN}β Training completed successfully!${NC}"

  # List the generated files. An unmatched glob makes `ls` fail, and under
  # `set -e` that must not abort the script after a successful run.
  echo -e "${BLUE}π Generated files:${NC}"
  ls -la "$EXPORT_DIR"/resnet_* || true

  # Report whether a best checkpoint was written.
  if [ -f "$EXPORT_DIR/resnet_item_embedder_best.pth" ]; then
    echo -e "${GREEN}π Best checkpoint saved: resnet_item_embedder_best.pth${NC}"
  fi

  # Summarize the metrics file, if present. This step is best-effort:
  # - The path is passed via argv rather than spliced into Python source.
  # - A missing or non-numeric best_triplet_loss prints N/A. Formatting the
  #   old "N/A" default with :.4f raised ValueError.
  # - A parse failure only prints a warning instead of killing the script.
  metrics_file="$EXPORT_DIR/resnet_metrics.json"
  if [ -f "$metrics_file" ]; then
    echo -e "${BLUE}π Training metrics saved: resnet_metrics.json${NC}"
    echo "Metrics summary:"
    python -c '
import json
import sys

with open(sys.argv[1]) as f:
    metrics = json.load(f)
best = metrics.get("best_triplet_loss")
if isinstance(best, (int, float)) and not isinstance(best, bool):
    print(f"Best triplet loss: {best:.4f}")
else:
    print("Best triplet loss: N/A")
history = metrics.get("history") or []
print(f"Training history: {len(history)} epochs")
' "$metrics_file" || echo -e "${YELLOW}Warning: could not summarize resnet_metrics.json${NC}"
  fi
else
  echo -e "${RED}β Training failed!${NC}"
  exit 1
fi
# Closing banner plus pointers to the follow-up scripts.
# printf '%b' expands the colour escapes exactly as `echo -e` does.
printf '%b\n' "${GREEN}π ResNet training script completed!${NC}"
printf '\n'
printf '%b\n' "${BLUE}Next steps:${NC}"
printf '%s\n' \
  "1. Train ViT outfit encoder: ./scripts/train_outfit.sh" \
  "2. Test inference: python app.py" \
  "3. Deploy to HF Space: ./scripts/deploy_space.sh"