recomendation / ui /advanced_training.py
Ali Mohsin
folder reorganise
72af8c3
"""
Advanced Training UI Components for Dressify
Provides comprehensive parameter controls for both ResNet and ViT training
"""
import gradio as gr
import os
import subprocess
import threading
import json
from typing import Dict, Any
def create_advanced_training_interface():
    """Build the advanced training UI and return it with its components.

    The layout has two rows. The first holds the per-model hyper-parameters
    (ResNet item embedder on the left, ViT outfit encoder on the right). The
    second holds the shared training settings next to the launch buttons,
    the training log and the status box.

    No event handlers are attached here; the caller wires the buttons using
    the returned component mapping.

    Returns:
        A ``(blocks, components)`` tuple: the ``gr.Blocks`` container and a
        dict mapping names to the Gradio components created inside it. The
        quick-training sliders are exposed as ``'epochs_res'`` and
        ``'epochs_vit'`` so callers can use them as inputs.
    """
    with gr.Blocks(title="Advanced Training Control") as training_interface:
        gr.Markdown("## 🎯 Comprehensive Training Parameter Control\nCustomize every aspect of model training for research and experimentation.")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### 🖼️ ResNet Item Embedder")
                # Model architecture
                resnet_backbone = gr.Dropdown(
                    choices=["resnet50", "resnet101"],
                    value="resnet50",
                    label="Backbone Architecture"
                )
                resnet_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                resnet_use_pretrained = gr.Checkbox(value=True, label="Use ImageNet Pretrained")
                resnet_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")
                # Training parameters
                resnet_epochs = gr.Slider(1, 100, value=20, step=1, label="Epochs")
                resnet_batch_size = gr.Slider(8, 128, value=64, step=8, label="Batch Size")
                resnet_lr = gr.Slider(1e-5, 1e-2, value=1e-3, step=1e-5, label="Learning Rate")
                resnet_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer"
                )
                resnet_weight_decay = gr.Slider(1e-6, 1e-2, value=1e-4, step=1e-6, label="Weight Decay")
                resnet_triplet_margin = gr.Slider(0.1, 1.0, value=0.2, step=0.05, label="Triplet Margin")

            with gr.Column(scale=1):
                gr.Markdown("#### 🧠 ViT Outfit Encoder")
                # Model architecture
                vit_embedding_dim = gr.Slider(128, 1024, value=512, step=128, label="Embedding Dimension")
                vit_num_layers = gr.Slider(2, 12, value=6, step=1, label="Transformer Layers")
                # NOTE(review): presumably the head count must divide the
                # embedding dimension - confirm against the ViT model code.
                vit_num_heads = gr.Slider(4, 16, value=8, step=2, label="Attention Heads")
                vit_ff_multiplier = gr.Slider(2, 8, value=4, step=1, label="Feed-Forward Multiplier")
                vit_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.05, label="Dropout Rate")
                # Training parameters
                vit_epochs = gr.Slider(1, 100, value=30, step=1, label="Epochs")
                vit_batch_size = gr.Slider(4, 64, value=32, step=4, label="Batch Size")
                vit_lr = gr.Slider(1e-5, 1e-2, value=5e-4, step=1e-5, label="Learning Rate")
                vit_optimizer = gr.Dropdown(
                    choices=["adamw", "adam", "sgd", "rmsprop"],
                    value="adamw",
                    label="Optimizer"
                )
                vit_weight_decay = gr.Slider(1e-4, 1e-1, value=5e-2, step=1e-4, label="Weight Decay")
                vit_triplet_margin = gr.Slider(0.1, 1.0, value=0.3, step=0.05, label="Triplet Margin")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("#### ⚙️ Advanced Training Settings")
                # Hardware optimization
                use_mixed_precision = gr.Checkbox(value=True, label="Mixed Precision (AMP)")
                channels_last = gr.Checkbox(value=True, label="Channels Last Memory Format")
                gradient_clip = gr.Slider(0.1, 5.0, value=1.0, step=0.1, label="Gradient Clipping")
                # Learning rate scheduling
                warmup_epochs = gr.Slider(0, 10, value=3, step=1, label="Warmup Epochs")
                scheduler_type = gr.Dropdown(
                    choices=["cosine", "step", "plateau", "linear"],
                    value="cosine",
                    label="Learning Rate Scheduler"
                )
                early_stopping_patience = gr.Slider(5, 20, value=10, step=1, label="Early Stopping Patience")
                # Training strategy
                mining_strategy = gr.Dropdown(
                    choices=["semi_hard", "hardest", "random"],
                    value="semi_hard",
                    label="Triplet Mining Strategy"
                )
                augmentation_level = gr.Dropdown(
                    choices=["minimal", "standard", "aggressive"],
                    value="standard",
                    label="Data Augmentation Level"
                )
                seed = gr.Slider(0, 9999, value=42, step=1, label="Random Seed")

            with gr.Column(scale=1):
                gr.Markdown("#### 🚀 Training Control")
                # Quick training
                gr.Markdown("**Quick Training (Basic Parameters)**")
                epochs_res = gr.Slider(1, 50, value=10, step=1, label="ResNet epochs")
                epochs_vit = gr.Slider(1, 100, value=20, step=1, label="ViT epochs")
                start_btn = gr.Button("🚀 Start Quick Training", variant="secondary")
                # Advanced training
                gr.Markdown("**Advanced Training (Custom Parameters)**")
                start_advanced_btn = gr.Button("🎯 Start Advanced Training", variant="primary")
                # Training log
                train_log = gr.Textbox(label="Training Log", lines=15, max_lines=20)
                # Status
                gr.Markdown("**Training Status**")
                training_status = gr.Textbox(label="Status", value="Ready to train", interactive=False)

    return training_interface, {
        'resnet_backbone': resnet_backbone,
        'resnet_embedding_dim': resnet_embedding_dim,
        'resnet_use_pretrained': resnet_use_pretrained,
        'resnet_dropout': resnet_dropout,
        'resnet_epochs': resnet_epochs,
        'resnet_batch_size': resnet_batch_size,
        'resnet_lr': resnet_lr,
        'resnet_optimizer': resnet_optimizer,
        'resnet_weight_decay': resnet_weight_decay,
        'resnet_triplet_margin': resnet_triplet_margin,
        'vit_embedding_dim': vit_embedding_dim,
        'vit_num_layers': vit_num_layers,
        'vit_num_heads': vit_num_heads,
        'vit_ff_multiplier': vit_ff_multiplier,
        'vit_dropout': vit_dropout,
        'vit_epochs': vit_epochs,
        'vit_batch_size': vit_batch_size,
        'vit_lr': vit_lr,
        'vit_optimizer': vit_optimizer,
        'vit_weight_decay': vit_weight_decay,
        'vit_triplet_margin': vit_triplet_margin,
        'use_mixed_precision': use_mixed_precision,
        'channels_last': channels_last,
        'gradient_clip': gradient_clip,
        'warmup_epochs': warmup_epochs,
        'scheduler_type': scheduler_type,
        'early_stopping_patience': early_stopping_patience,
        'mining_strategy': mining_strategy,
        'augmentation_level': augmentation_level,
        'seed': seed,
        # Quick-training sliders: previously created but not returned, which
        # left callers unable to use them as event inputs.
        'epochs_res': epochs_res,
        'epochs_vit': epochs_vit,
        'start_btn': start_btn,
        'start_advanced_btn': start_advanced_btn,
        'train_log': train_log,
        'training_status': training_status,
    }
def start_advanced_training(
    # ResNet parameters
    resnet_epochs: int, resnet_batch_size: int, resnet_lr: float, resnet_optimizer: str,
    resnet_weight_decay: float, resnet_triplet_margin: float, resnet_embedding_dim: int,
    resnet_backbone: str, resnet_use_pretrained: bool, resnet_dropout: float,
    # ViT parameters
    vit_epochs: int, vit_batch_size: int, vit_lr: float, vit_optimizer: str,
    vit_weight_decay: float, vit_triplet_margin: float, vit_embedding_dim: int,
    vit_num_layers: int, vit_num_heads: int, vit_ff_multiplier: int, vit_dropout: float,
    # Advanced parameters
    use_mixed_precision: bool, channels_last: bool, gradient_clip: float,
    warmup_epochs: int, scheduler_type: str, early_stopping_patience: int,
    mining_strategy: str, augmentation_level: str, seed: int,
    dataset_root: str = None
):
    """Launch ResNet then ViT training in a background thread.

    The dataset location is validated synchronously. The actual work runs in
    a daemon thread named ``advanced-training`` so the UI handler returns at
    once. The worker:

    1. writes ``resnet_config_custom.json`` and ``vit_config_custom.json``
       (every UI setting) into the export directory;
    2. runs ``training/train_resnet.py`` and, only if that succeeds,
       ``training/train_vit.py``, forwarding the subset of settings passed
       as command-line flags below;
    3. records progress in ``advanced_training.log`` inside the export
       directory and through the module logger.

    Args:
        resnet_*: Hyper-parameters for the ResNet item embedder.
        vit_*: Hyper-parameters for the ViT outfit encoder.
        use_mixed_precision, channels_last, gradient_clip, warmup_epochs,
        scheduler_type, early_stopping_patience, mining_strategy,
        augmentation_level, seed: Shared settings stored in the configs.
        dataset_root: Dataset directory; defaults to ``$POLYVORE_ROOT`` or
            ``data/Polyvore``.

    Returns:
        A status string: an error message when the dataset directory does
        not exist, otherwise a confirmation naming the log file.
    """
    # Function-scope imports keep this block self-contained.
    import logging
    import sys

    logger = logging.getLogger(__name__)

    if not dataset_root:
        dataset_root = os.getenv("POLYVORE_ROOT", "data/Polyvore")
    if not os.path.exists(dataset_root):
        return "❌ Dataset not ready. Please wait for bootstrap to complete."

    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    log_path = os.path.join(export_dir, "advanced_training.log")
    # Use the interpreter hosting this process so the training scripts see
    # the same environment; fall back to PATH lookup if it is unknown.
    python_exe = sys.executable or "python"

    def _log(message: str, reset: bool = False) -> None:
        """Record a progress message in the logger and the on-disk log."""
        # The worker used to mutate ``train_log.value``, but no ``train_log``
        # exists in this module's scope, so it died with NameError.
        logger.info("%s", message.rstrip("\n"))
        try:
            with open(log_path, "w" if reset else "a", encoding="utf-8") as log_file:
                log_file.write(message)
        except OSError:
            logger.exception("Could not write to training log %s", log_path)

    def _runner():
        try:
            os.makedirs(export_dir, exist_ok=True)

            # Create custom config files
            resnet_config = {
                "model": {
                    "backbone": resnet_backbone,
                    "embedding_dim": resnet_embedding_dim,
                    "pretrained": resnet_use_pretrained,
                    "dropout": resnet_dropout
                },
                "training": {
                    "batch_size": resnet_batch_size,
                    "epochs": resnet_epochs,
                    "lr": resnet_lr,
                    "weight_decay": resnet_weight_decay,
                    "triplet_margin": resnet_triplet_margin,
                    "optimizer": resnet_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision,
                    "channels_last": channels_last,
                    "gradient_clip": gradient_clip
                },
                "data": {
                    "image_size": 224,
                    "augmentation_level": augmentation_level
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }
            vit_config = {
                "model": {
                    "embedding_dim": vit_embedding_dim,
                    "num_layers": vit_num_layers,
                    "num_heads": vit_num_heads,
                    "ff_multiplier": vit_ff_multiplier,
                    "dropout": vit_dropout
                },
                "training": {
                    "batch_size": vit_batch_size,
                    "epochs": vit_epochs,
                    "lr": vit_lr,
                    "weight_decay": vit_weight_decay,
                    "triplet_margin": vit_triplet_margin,
                    "optimizer": vit_optimizer,
                    "scheduler": scheduler_type,
                    "warmup_epochs": warmup_epochs,
                    "early_stopping_patience": early_stopping_patience,
                    "use_amp": use_mixed_precision
                },
                "advanced": {
                    "mining_strategy": mining_strategy,
                    "seed": seed
                }
            }

            # Save configs
            with open(os.path.join(export_dir, "resnet_config_custom.json"), "w", encoding="utf-8") as f:
                json.dump(resnet_config, f, indent=2)
            with open(os.path.join(export_dir, "vit_config_custom.json"), "w", encoding="utf-8") as f:
                json.dump(vit_config, f, indent=2)

            # Train ResNet with custom parameters
            _log(
                "🚀 Starting ResNet training with custom parameters...\n"
                f"Backbone: {resnet_backbone}, Embedding Dim: {resnet_embedding_dim}\n"
                f"Epochs: {resnet_epochs}, Batch Size: {resnet_batch_size}, LR: {resnet_lr}\n"
                f"Optimizer: {resnet_optimizer}, Triplet Margin: {resnet_triplet_margin}\n",
                reset=True,
            )
            # Integer-valued settings are normalised with int() so a value
            # such as 20.0 is not passed to the script as "20.0".
            resnet_cmd = [
                python_exe, "training/train_resnet.py",
                "--data_root", dataset_root,
                "--epochs", str(int(resnet_epochs)),
                "--batch_size", str(int(resnet_batch_size)),
                "--lr", str(resnet_lr),
                "--weight_decay", str(resnet_weight_decay),
                "--triplet_margin", str(resnet_triplet_margin),
                "--embedding_dim", str(int(resnet_embedding_dim)),
                "--out", os.path.join(export_dir, "resnet_item_embedder_custom.pth")
            ]
            if resnet_backbone != "resnet50":
                resnet_cmd.extend(["--backbone", resnet_backbone])
            result = subprocess.run(resnet_cmd, capture_output=True, text=True, check=False)
            if result.returncode != 0:
                # Do not start the ViT stage after a failed ResNet stage.
                _log(f"❌ ResNet training failed: {result.stderr}\n\n")
                return
            _log("✅ ResNet training completed successfully!\n\n")

            # Train ViT with custom parameters
            _log(
                "🚀 Starting ViT training with custom parameters...\n"
                f"Layers: {vit_num_layers}, Heads: {vit_num_heads}, FF Multiplier: {vit_ff_multiplier}\n"
                f"Epochs: {vit_epochs}, Batch Size: {vit_batch_size}, LR: {vit_lr}\n"
                f"Optimizer: {vit_optimizer}, Triplet Margin: {vit_triplet_margin}\n"
            )
            vit_cmd = [
                python_exe, "training/train_vit.py",
                "--data_root", dataset_root,
                "--epochs", str(int(vit_epochs)),
                "--batch_size", str(int(vit_batch_size)),
                "--lr", str(vit_lr),
                "--weight_decay", str(vit_weight_decay),
                "--triplet_margin", str(vit_triplet_margin),
                "--embedding_dim", str(int(vit_embedding_dim)),
                "--export", os.path.join(export_dir, "vit_outfit_model_custom.pth")
            ]
            result = subprocess.run(vit_cmd, capture_output=True, text=True, check=False)
            if result.returncode == 0:
                _log(
                    "✅ ViT training completed successfully!\n\n"
                    f"🎉 All training completed! Models saved to {export_dir}\n"
                    # Nothing is reloaded here: service.reload_models() would
                    # need to be called from the main app.
                    "🔄 Reload the models from the main app to use them for inference.\n"
                )
            else:
                _log(f"❌ ViT training failed: {result.stderr}\n")
        except Exception as e:
            # Top-level boundary of the worker thread: report and stop.
            _log(f"\n❌ Training error: {str(e)}")

    threading.Thread(target=_runner, name="advanced-training", daemon=True).start()
    return f"🚀 Advanced training started with custom parameters! Progress is written to {log_path}"
def start_simple_training(res_epochs: int, vit_epochs: int, dataset_root: str = None):
    """Launch quick ResNet then ViT training in a background thread.

    Only the epoch counts are configurable; every other setting is left to
    the training scripts' defaults. Progress messages go to
    ``quick_training.log`` inside the export directory and to the module
    logger. Subprocess output is not captured, so the scripts print to the
    hosting process's console.

    Args:
        res_epochs: Number of epochs passed to ``training/train_resnet.py``.
        vit_epochs: Number of epochs passed to ``training/train_vit.py``.
        dataset_root: Dataset directory; defaults to ``$POLYVORE_ROOT`` or
            ``data/Polyvore``.

    Returns:
        ``"Dataset not ready."`` when the dataset directory does not exist,
        otherwise ``"Started"``.
    """
    # Function-scope imports keep this block self-contained.
    import logging
    import sys

    logger = logging.getLogger(__name__)

    if not dataset_root:
        dataset_root = os.getenv("POLYVORE_ROOT", "data/Polyvore")
    # Report a missing dataset through the return value. The check used to
    # run inside the worker, which wrote to an undefined ``train_log`` while
    # the caller was told "Started".
    if not os.path.exists(dataset_root):
        return "Dataset not ready."

    export_dir = os.getenv("EXPORT_DIR", "models/exports")
    log_path = os.path.join(export_dir, "quick_training.log")
    # Use the interpreter hosting this process; fall back to PATH lookup.
    python_exe = sys.executable or "python"

    def _log(message: str, reset: bool = False) -> None:
        """Record a progress message in the logger and the on-disk log."""
        logger.info("%s", message.strip("\n"))
        try:
            with open(log_path, "w" if reset else "a", encoding="utf-8") as log_file:
                log_file.write(message)
        except OSError:
            logger.exception("Could not write to training log %s", log_path)

    def _runner():
        try:
            os.makedirs(export_dir, exist_ok=True)

            _log("Training ResNet…\n", reset=True)
            resnet_run = subprocess.run([
                python_exe, "training/train_resnet.py", "--data_root", dataset_root, "--epochs", str(int(res_epochs)),
                "--out", os.path.join(export_dir, "resnet_item_embedder.pth")
            ], check=False)
            if resnet_run.returncode != 0:
                _log(f"\nResNet training failed (exit code {resnet_run.returncode}).")

            # As before, the ViT stage runs regardless of the ResNet result;
            # the difference is that failures are now reported.
            _log("\nTraining ViT (triplet)…\n")
            vit_run = subprocess.run([
                python_exe, "training/train_vit.py", "--data_root", dataset_root, "--epochs", str(int(vit_epochs)),
                "--export", os.path.join(export_dir, "vit_outfit_model.pth")
            ], check=False)
            if vit_run.returncode != 0:
                _log(f"\nViT training failed (exit code {vit_run.returncode}).")

            if resnet_run.returncode == 0 and vit_run.returncode == 0:
                _log(f"\nDone. Artifacts in {export_dir}.")
            else:
                _log("\nFinished with errors.")
        except Exception as e:
            # Top-level boundary of the worker thread: report and stop.
            _log(f"\nError: {e}")

    threading.Thread(target=_runner, name="quick-training", daemon=True).start()
    return "Started"
# Example usage: build the interface, wire both training buttons, and launch.
if __name__ == "__main__":
    interface, components = create_advanced_training_interface()

    # Gradio only accepts event listeners while a Blocks context is active,
    # so re-enter the returned Blocks before attaching the click handlers.
    with interface:
        # Quick training reads its own epoch sliders when the mapping exposes
        # them; otherwise fall back to the advanced epoch sliders.
        components['start_btn'].click(
            fn=start_simple_training,
            inputs=[
                components.get('epochs_res', components['resnet_epochs']),
                components.get('epochs_vit', components['vit_epochs']),
            ],
            outputs=components['train_log']
        )
        # Input order must match the positional parameters of
        # start_advanced_training.
        components['start_advanced_btn'].click(
            fn=start_advanced_training,
            inputs=[
                components['resnet_epochs'], components['resnet_batch_size'], components['resnet_lr'],
                components['resnet_optimizer'], components['resnet_weight_decay'], components['resnet_triplet_margin'],
                components['resnet_embedding_dim'], components['resnet_backbone'], components['resnet_use_pretrained'],
                components['resnet_dropout'], components['vit_epochs'], components['vit_batch_size'], components['vit_lr'],
                components['vit_optimizer'], components['vit_weight_decay'], components['vit_triplet_margin'],
                components['vit_embedding_dim'], components['vit_num_layers'], components['vit_num_heads'],
                components['vit_ff_multiplier'], components['vit_dropout'], components['use_mixed_precision'],
                components['channels_last'], components['gradient_clip'], components['warmup_epochs'],
                components['scheduler_type'], components['early_stopping_patience'], components['mining_strategy'],
                components['augmentation_level'], components['seed']
            ],
            outputs=components['train_log']
        )

    interface.launch()