| """ | |
| Advanced Training UI Components for Dressify | |
| Provides comprehensive parameter controls for both ResNet and ViT training | |
| """ | |
| import gradio as gr | |
| import os | |
| import subprocess | |
| import threading | |
| import json | |
| from typing import Dict, Any | |
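
# --- Shared training-log buffer --------------------------------------------
# The training runners below execute in background daemon threads, so they
# cannot write to the Gradio "Training Log" textbox directly (that component
# lives inside create_advanced_training_interface() and is not in scope here).
# As a minimal sketch, progress messages are collected in this module-level
# buffer, and the UI can poll get_training_log() to display them. The names
# _log_lines, _append_log and get_training_log are introduced in this revision
# and are not part of the original training scripts or UI.
_log_lines = []  # list of progress messages produced by the training threads


def _append_log(message: str) -> None:
    """Append one progress message to the shared training log buffer."""
    _log_lines.append(message)


def get_training_log() -> str:
    """Return the accumulated training log as a single string."""
    return "\n".join(_log_lines)
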

def create_advanced_training_interface():
    """Create the advanced training interface with all parameter controls."""
    with gr.Blocks(title="Advanced Training Control") as training_interface:
        gr.Markdown("## Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### ResNet Item Embedder")

                # Model architecture
                resnet_backbone = gr.Dropdown(
                    choices=["resnet50", "resnet101"],
                    value="resnet50",
                    label="Backbone Architecture"
                )
                resnet_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                resnet_use_pretrained = gr.Checkbox(value=True, label="Use ImageNet Pretrained")
                resnet_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")

                # Training parameters
                resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
                resnet_batch_size = gr.Slider(8, 128, value=64, step=8, label="Batch Size")
                resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
                resnet_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer"
                )
                resnet_weight_decay = gr.Slider(1e-6, 1e-2, value=1e-4, step=1e-6, label="Weight Decay")
                resnet_triplet_margin = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Triplet Margin")

            with gr.Column(scale=1):
                gr.Markdown("#### ViT Outfit Encoder")

                # Model architecture
                vit_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                vit_num_layers = gr.Slider(2, 12, value=6, step=1, label="Transformer Layers")
                vit_num_heads = gr.Slider(4, 16, value=8, step=2, label="Attention Heads")
                vit_ff_multiplier = gr.Slider(2, 8, value=4, step=1, label="Feed-Forward Multiplier")
                vit_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")

                # Training parameters
                vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
                vit_batch_size = gr.Slider(4, 64, value=32, step=4, label="Batch Size")
                vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
                vit_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer"
                )
                vit_weight_decay = gr.Slider(1e-4, 1e-1, value=5e-2, step=1e-4, label="Weight Decay")
                vit_triplet_margin = gr.Slider(0.1, 1.0, value=0.3, step=0.05, label="Triplet Margin")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### Advanced Training Settings")

                # Hardware optimization
                use_mixed_precision = gr.Checkbox(value=True, label="Mixed Precision (AMP)")
                channels_last = gr.Checkbox(value=True, label="Channels Last Memory Format")
                gradient_clip = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="Gradient Clipping")

                # Learning rate scheduling
                warmup_epochs = gr.Slider(0, 10, value=3, step=1, label="Warmup Epochs")
                scheduler_type = gr.Dropdown(
                    choices=["cosine", "step", "plateau", "linear"],
                    value="cosine",
                    label="Learning Rate Scheduler"
                )
                early_stopping_patience = gr.Slider(5, 20, value=10, step=1, label="Early Stopping Patience")

                # Training strategy
                mining_strategy = gr.Dropdown(
                    choices=["semi_hard", "hardest", "random"],
                    value="semi_hard",
                    label="Triplet Mining Strategy"
                )
                augmentation_level = gr.Dropdown(
                    choices=["minimal", "standard", "aggressive"],
                    value="standard",
                    label="Data Augmentation Level"
                )
                seed = gr.Slider(0, 9999, value=42, step=1, label="Random Seed")

            with gr.Column(scale=1):
                gr.Markdown("#### Training Control")

                # Quick training
                gr.Markdown("**Quick Training (Basic Parameters)**")
                epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
                epochs_vit = gr.Slider(1, 100, value=20, step=1, label="ViT epochs")
                start_btn = gr.Button("Start Quick Training", variant="secondary")

                # Advanced training
                gr.Markdown("**Advanced Training (Custom Parameters)**")
                start_advanced_btn = gr.Button("Start Advanced Training", variant="primary")

                # Training log
                train_log = gr.Textbox(label="Training Log", lines=15, max_lines=20)

                # Status
                gr.Markdown("**Training Status**")
                training_status = gr.Textbox(label="Status", value="Ready to train", interactive=False)
    return training_interface, {
        'resnet_backbone': resnet_backbone,
        'resnet_embedding_dim': resnet_embedding_dim,
        'resnet_use_pretrained': resnet_use_pretrained,
        'resnet_dropout': resnet_dropout,
        'resnet_epochs': resnet_epochs,
        'resnet_batch_size': resnet_batch_size,
        'resnet_lr': resnet_lr,
        'resnet_optimizer': resnet_optimizer,
        'resnet_weight_decay': resnet_weight_decay,
        'resnet_triplet_margin': resnet_triplet_margin,
        'vit_embedding_dim': vit_embedding_dim,
        'vit_num_layers': vit_num_layers,
        'vit_num_heads': vit_num_heads,
        'vit_ff_multiplier': vit_ff_multiplier,
        'vit_dropout': vit_dropout,
        'vit_epochs': vit_epochs,
        'vit_batch_size': vit_batch_size,
        'vit_lr': vit_lr,
        'vit_optimizer': vit_optimizer,
        'vit_weight_decay': vit_weight_decay,
        'vit_triplet_margin': vit_triplet_margin,
        'use_mixed_precision': use_mixed_precision,
        'channels_last': channels_last,
        'gradient_clip': gradient_clip,
        'warmup_epochs': warmup_epochs,
        'scheduler_type': scheduler_type,
        'early_stopping_patience': early_stopping_patience,
        'mining_strategy': mining_strategy,
        'augmentation_level': augmentation_level,
        'seed': seed,
        # Quick-training sliders are exposed so the quick-start button can use them.
        'epochs_res': epochs_res,
        'epochs_vit': epochs_vit,
        'start_btn': start_btn,
        'start_advanced_btn': start_advanced_btn,
        'train_log': train_log,
        'training_status': training_status
    }

def start_advanced_training(
    # ResNet parameters
    resnet_epochs: int, resnet_batch_size: int, resnet_lr: float, resnet_optimizer: str,
    resnet_weight_decay: float, resnet_triplet_margin: float, resnet_embedding_dim: int,
    resnet_backbone: str, resnet_use_pretrained: bool, resnet_dropout: float,
    # ViT parameters
    vit_epochs: int, vit_batch_size: int, vit_lr: float, vit_optimizer: str,
    vit_weight_decay: float, vit_triplet_margin: float, vit_embedding_dim: int,
    vit_num_layers: int, vit_num_heads: int, vit_ff_multiplier: int, vit_dropout: float,
    # Advanced parameters
    use_mixed_precision: bool, channels_last: bool, gradient_clip: float,
    warmup_epochs: int, scheduler_type: str, early_stopping_patience: int,
    mining_strategy: str, augmentation_level: str, seed: int,
    dataset_root: Optional[str] = None
):
    """Start advanced training with custom parameters in a background thread."""
    if not dataset_root:
        dataset_root = os.getenv("POLYVORE_ROOT", "data/Polyvore")
    if not os.path.exists(dataset_root):
        return "Dataset not ready. Please wait for bootstrap to complete."

    def _runner():
        try:
            # Start a fresh log for this run (mirrors the original behaviour of
            # resetting the log text at the start of training).
            _log_lines.clear()

            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            os.makedirs(export_dir, exist_ok=True)

            # Build custom config files. They are saved alongside the exports
            # for reproducibility; the training scripts below receive their
            # settings through the CLI flags used further down.
            resnet_config = {
                "model": {
                    "backbone": resnet_backbone,
                    "embedding_dim": resnet_embedding_dim,
                    "pretrained": resnet_use_pretrained,
                    "dropout": resnet_dropout
                },
                "training": {
                    "batch_size": resnet_batch_size,
                    "epochs": resnet_epochs,
                    "lr": resnet_lr,
                    "weight_decay": resnet_weight_decay,
                    "triplet_margin": resnet_triplet_margin,
                    "optimizer": resnet_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision,
                    "channels_last": channels_last,
                    "gradient_clip": gradient_clip
                },
                "data": {
                    "image_size": 224,
                    "augmentation_level": augmentation_level
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }

            vit_config = {
                "model": {
                    "embedding_dim": vit_embedding_dim,
                    "num_layers": vit_num_layers,
                    "num_heads": vit_num_heads,
                    "ff_multiplier": vit_ff_multiplier,
                    "dropout": vit_dropout
                },
                "training": {
                    "batch_size": vit_batch_size,
                    "epochs": vit_epochs,
                    "lr": vit_lr,
                    "weight_decay": vit_weight_decay,
                    "triplet_margin": vit_triplet_margin,
                    "optimizer": vit_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }

            # Save configs
            with open(os.path.join(export_dir, "resnet_config_custom.json"), "w") as f:
                json.dump(resnet_config, f, indent=2)
            with open(os.path.join(export_dir, "vit_config_custom.json"), "w") as f:
                json.dump(vit_config, f, indent=2)

            # Train ResNet with custom parameters
            _append_log("Starting ResNet training with custom parameters...")
            _append_log(f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}")
            _append_log(f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}")
            _append_log(f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}")

            resnet_cmd = [
                "python", "training/train_resnet.py",
                "--data_root", dataset_root,
                "--epochs", str(resnet_epochs),
                "--batch_size", str(resnet_batch_size),
                "--lr", str(resnet_lr),
                "--weight_decay", str(resnet_weight_decay),
                "--triplet_margin", str(resnet_triplet_margin),
                "--embedding_dim", str(resnet_embedding_dim),
                "--out", os.path.join(export_dir, "resnet_item_embedder_custom.pth")
            ]
            if resnet_backbone != "resnet50":
                resnet_cmd.extend(["--backbone", resnet_backbone])

            result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
            if result.returncode == 0:
                _append_log("ResNet training completed successfully!\n")
            else:
                _append_log(f"ResNet training failed: {result.stderr}\n")
                return

            # Train ViT with custom parameters
            _append_log("Starting ViT training with custom parameters...")
            _append_log(f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}")
            _append_log(f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}")
            _append_log(f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}")

            vit_cmd = [
                "python", "training/train_vit.py",
                "--data_root", dataset_root,
                "--epochs", str(vit_epochs),
                "--batch_size", str(vit_batch_size),
                "--lr", str(vit_lr),
                "--weight_decay", str(vit_weight_decay),
                "--triplet_margin", str(vit_triplet_margin),
                "--embedding_dim", str(vit_embedding_dim),
                "--export", os.path.join(export_dir, "vit_outfit_model_custom.pth")
            ]

            result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
            if result.returncode == 0:
                _append_log("ViT training completed successfully!\n")
                _append_log("All training completed! Models saved to models/exports/")
                # Note: service.reload_models() would need to be called from the
                # main app before the new weights are picked up for inference.
                _append_log("Reload models from the main app to use them for inference.")
            else:
                _append_log(f"ViT training failed: {result.stderr}")
        except Exception as e:
            _append_log(f"Training error: {e}")

    threading.Thread(target=_runner, daemon=True).start()
    return "Advanced training started with custom parameters. Refresh the training log to follow progress."

def start_simple_training(res_epochs: int, vit_epochs: int, dataset_root: Optional[str] = None):
    """Start simple training with basic parameters in a background thread."""
    if not dataset_root:
        dataset_root = os.getenv("POLYVORE_ROOT", "data/Polyvore")

    def _runner():
        try:
            # Start a fresh log for this run.
            _log_lines.clear()

            if not os.path.exists(dataset_root):
                _append_log("Dataset not ready.")
                return

            export_dir = os.getenv("EXPORT_DIR", "models/exports")
            os.makedirs(export_dir, exist_ok=True)

            _append_log("Training ResNet…")
            subprocess.run([
                "python", "training/train_resnet.py", "--data_root", dataset_root, "--epochs", str(res_epochs),
                "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
            ], check=False)

            _append_log("Training ViT (triplet)…")
            subprocess.run([
                "python", "training/train_vit.py", "--data_root", dataset_root, "--epochs", str(vit_epochs),
                "--export", os.path.join(export_dir, "vit_outfit_model.pth")
            ], check=False)

            _append_log("Done. Artifacts in models/exports.")
        except Exception as e:
            _append_log(f"Error: {e}")

    threading.Thread(target=_runner, daemon=True).start()
    return "Started"

# Example usage
if __name__ == "__main__":
    interface, components = create_advanced_training_interface()

    # Event listeners must be attached inside a Blocks context, so the
    # interface is re-entered before wiring the buttons.
    with interface:
        # Quick training uses the dedicated quick-training epoch sliders.
        components['start_btn'].click(
            fn=start_simple_training,
            inputs=[components['epochs_res'], components['epochs_vit']],
            outputs=components['train_log']
        )

        # Advanced training passes every parameter in the order expected by
        # start_advanced_training().
        components['start_advanced_btn'].click(
            fn=start_advanced_training,
            inputs=[
                components['resnet_epochs'], components['resnet_batch_size'], components['resnet_lr'],
                components['resnet_optimizer'], components['resnet_weight_decay'], components['resnet_triplet_margin'],
                components['resnet_embedding_dim'], components['resnet_backbone'], components['resnet_use_pretrained'],
                components['resnet_dropout'], components['vit_epochs'], components['vit_batch_size'], components['vit_lr'],
                components['vit_optimizer'], components['vit_weight_decay'], components['vit_triplet_margin'],
                components['vit_embedding_dim'], components['vit_num_layers'], components['vit_num_heads'],
                components['vit_ff_multiplier'], components['vit_dropout'], components['use_mixed_precision'],
                components['channels_last'], components['gradient_clip'], components['warmup_epochs'],
                components['scheduler_type'], components['early_stopping_patience'], components['mining_strategy'],
                components['augmentation_level'], components['seed']
            ],
            outputs=components['train_log']
        )
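
        # Hedged sketch (not part of the original module): the training threads
        # only write to the module-level log buffer defined near the imports, so
        # the "Training Log" textbox needs an explicit refresh. While the Blocks
        # context is still open, a simple refresh button can pull the buffered
        # messages into the textbox; the button and get_training_log() are
        # additions of this revision, not original controls.
        refresh_log_btn = gr.Button("Refresh Training Log")
        refresh_log_btn.click(
            fn=get_training_log,
            inputs=None,
            outputs=components['train_log']
        )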

    interface.launch()