recomendation / scripts /train_item.sh
Ali Mohsin
more try
8bcf79a
#!/bin/bash
# Dressify - Train ResNet Item Embedder
#
# This script trains the ResNet50 item embedder on the Polyvore dataset.
#
# Optional environment overrides:
#   POLYVORE_ROOT  dataset root directory (default: data/Polyvore)
#   EPOCHS         value passed as --epochs (default: 20)
#   BATCH_SIZE     value passed as --batch_size (default: 64)
#   LR             value passed as --lr (default: 0.001)

# Strict mode:
#   -e           exit on any unhandled non-zero status
#   -u           treat expansion of an unset variable as an error
#   -o pipefail  a pipeline fails if any of its stages fails
set -euo pipefail

# Configuration
# NOTE(review): CONFIG_FILE is not referenced anywhere in the visible part of
# this script; kept in case something outside this view relies on it.
# shellcheck disable=SC2034
CONFIG_FILE="configs/item.yaml"
DATA_ROOT="${POLYVORE_ROOT:-data/Polyvore}"
EXPORT_DIR="models/exports"
EPOCHS="${EPOCHS:-20}"
BATCH_SIZE="${BATCH_SIZE:-64}"
LR="${LR:-0.001}"

# Colors for output.
# Stored as literal backslash sequences (single quotes); they only turn into
# real escape codes when printed with `echo -e`.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Start-up banner. The emoji were stored as mis-decoded UTF-8 (mojibake) and
# are restored to the intended characters here.
echo -e "${BLUE}🚀 Starting ResNet Item Embedder Training${NC}"
echo "=================================================="

# Check if dataset exists; if the root directory is missing, run the
# preparation script against it.
if [[ ! -d "$DATA_ROOT" ]]; then
  echo -e "${YELLOW}⚠️ Dataset not found at $DATA_ROOT${NC}"
  echo "Running dataset preparation..."
  python scripts/prepare_polyvore.py --root "$DATA_ROOT" --random_split
fi

# Check if splits exist; this is evaluated after the step above, so the
# preparation script only runs a second time when train.json is still absent.
if [[ ! -f "$DATA_ROOT/splits/train.json" ]]; then
  echo -e "${YELLOW}⚠️ Training splits not found${NC}"
  echo "Creating splits..."
  python scripts/prepare_polyvore.py --root "$DATA_ROOT" --random_split
fi
# Create export directory (no-op when it already exists).
mkdir -p -- "$EXPORT_DIR"

# Check for existing checkpoints.
# START_FROM_CHECKPOINT ends up as either "--resume" or the empty string and
# is appended to the training command line further down.
# NOTE(review): the probe looks for resnet_item_embedder_best.pth while the
# training command writes to resnet_item_embedder.pth via --out; how
# train_resnet.py locates the checkpoint for --resume is not visible here.
if [[ -f "$EXPORT_DIR/resnet_item_embedder_best.pth" ]]; then
  echo -e "${GREEN}✅ Found existing best checkpoint${NC}"
  echo "Starting from existing model..."
  START_FROM_CHECKPOINT="--resume"
else
  echo -e "${BLUE}🆕 No existing checkpoint found, starting fresh${NC}"
  START_FROM_CHECKPOINT=""
fi
# Training command: echo the effective settings before starting.
echo -e "${BLUE}🎯 Training Configuration:${NC}"
echo " Data Root: $DATA_ROOT"
echo " Epochs: $EPOCHS"
echo " Batch Size: $BATCH_SIZE"
echo " Learning Rate: $LR"
echo " Export Dir: $EXPORT_DIR"
echo ""

# Run training
echo -e "${BLUE}🔥 Starting training...${NC}"

# The training command is the `if` condition on purpose: with `set -e` active,
# a failing command in a plain statement terminates the script on the spot, so
# a separate `[ $? -eq 0 ]` check afterwards could never reach its failure
# branch. Commands used as an `if` condition are exempt from `set -e`.
#
# ${VAR:+"$VAR"} expands to nothing when VAR is empty/unset and to exactly one
# quoted word otherwise, so the optional flag is passed without word-splitting.
if python train_resnet.py \
  --data_root "$DATA_ROOT" \
  --epochs "$EPOCHS" \
  --batch_size "$BATCH_SIZE" \
  --lr "$LR" \
  --out "$EXPORT_DIR/resnet_item_embedder.pth" \
  ${START_FROM_CHECKPOINT:+"$START_FROM_CHECKPOINT"}; then
  echo -e "${GREEN}✅ Training completed successfully!${NC}"

  # List generated files. The listing is purely informational, so an
  # unmatched glob (ls exits non-zero) must not abort a successful run;
  # ls still reports the problem on stderr.
  echo -e "${BLUE}📁 Generated files:${NC}"
  ls -la "$EXPORT_DIR"/resnet_* || true

  # Check if best checkpoint exists
  if [[ -f "$EXPORT_DIR/resnet_item_embedder_best.pth" ]]; then
    echo -e "${GREEN}🏆 Best checkpoint saved: resnet_item_embedder_best.pth${NC}"
  fi

  # Check metrics
  if [[ -f "$EXPORT_DIR/resnet_metrics.json" ]]; then
    echo -e "${BLUE}📊 Training metrics saved: resnet_metrics.json${NC}"
    echo "Metrics summary:"
    # The path is handed over as an argument (sys.argv[1]) instead of being
    # spliced into the Python source, so quotes or other special characters
    # in the path cannot break the snippet.
    # The loss is only formatted with :.4f when it is numeric; applying a
    # float format to a missing/non-numeric value would raise an exception.
    python -c '
import json
import sys

with open(sys.argv[1]) as f:
    metrics = json.load(f)

best = metrics.get("best_triplet_loss")
if isinstance(best, (int, float)):
    print(f"Best triplet loss: {best:.4f}")
else:
    print("Best triplet loss: N/A")

history = metrics.get("history", [])
print(f"Training history: {len(history)} epochs")
' "$EXPORT_DIR/resnet_metrics.json"
  fi
else
  echo -e "${RED}❌ Training failed!${NC}"
  exit 1
fi
# Final status line plus pointers to the follow-up steps. The emoji was stored
# as mis-decoded UTF-8 (mojibake) and is restored to the intended character.
echo -e "${GREEN}🎉 ResNet training script completed!${NC}"
echo ""
echo -e "${BLUE}Next steps:${NC}"
# Quoted heredoc delimiter: the lines are emitted literally, no expansion.
cat <<'EOF'
1. Train ViT outfit encoder: ./scripts/train_outfit.sh
2. Test inference: python app.py
3. Deploy to HF Space: ./scripts/deploy_space.sh
EOF