# UnReflectAnything / configs / highlight_decoder_pretrain.yaml
# AlbeRota's picture
# Upload weights, notebooks, sample images
# 10a2918 verified
### BASELINE: CONVERGES AFTER LONG TRAINING
parameters:
  # Every entry below is a named parameter whose payload is stored under `value`.
  # NOTE(review): the nesting of the grouped entries (MODEL, DATASETS, LR_SCHEDULER)
  # follows the key names and their comments - confirm against the config loader.

  ### MODEL ARCHITECTURE
  MODEL:
    value:
      # Decoder pretraining: the token inpainter is not used (features come directly from DINO).
      MODEL_CLASS: "UnReflect_Model"  # Main model class name (must match class in models.py)
      MODEL_MODULE: "models"  # Module name to import model classes from (default: "models")
      RGB_ENCODER:
        ENCODER: "facebook/dinov3-vitl16-pretrain-lvd1689m"  # DINOv3 encoder model name (HuggingFace format)
        IMAGE_SIZE: 896  # Input image size (height and width in pixels)
        RETURN_SELECTED_LAYERS: [3, 6, 9, 12]  # Transformer layer indices to extract features from (0-indexed)
        RGB_ENCODER_LR: 0.0  # Learning rate for RGB encoder (0.0 = frozen, must be explicitly set)
      DECODERS:
        highlight:
          FEATURE_DIM: 1024  # Feature dimension for highlight decoder
          REASSEMBLE_OUT_CHANNELS: [96, 192, 384, 768]  # Output channels for each decoder stage
          # Spatial resize factor for each stage (values > 1.0 upsample, values < 1.0 downsample)
          REASSEMBLE_FACTORS: [4.0, 2.0, 1.0, 0.5]
          READOUT_TYPE: "ignore"  # Readout type for DPT decoder
          # FROM_PRETRAINED: "highlight_decoder.pt"  # Path to pretrained highlight decoder weights (optional)
          USE_BN: false  # Use batch normalization in decoder
          DROPOUT: 0.1  # Dropout rate in decoder layers
          OUTPUT_IMAGE_SIZE: [896, 896]  # Output image resolution [height, width]
          OUTPUT_CHANNELS: 1  # Number of output channels (1 for highlight mask)
          # Custom learning rate for decoder (0.0 = frozen, 1.0 = same as base LR)
          # NOTE(review): the description reads like a multiplier of the base LR, while 1.0e-6
          # looks like an absolute rate - confirm which one the trainer expects.
          DECODER_LR: 1.0e-6
          # Number of fusion blocks to train (0-4, null = train all if DECODER_LR != 0)
          NUM_FUSION_BLOCKS_TRAINABLE: null
      # Not relevant for this run: the selected model class does not use the token inpainter.
      TOKEN_INPAINTER:
        TOKEN_INPAINTER_CLASS: "TokenInpainter_Prior"  # Token inpainter class name
        TOKEN_INPAINTER_MODULE: "token_inpainters"  # Module name to import token inpainter from
        FROM_PRETRAINED: "token_inpainter.pth"  # Path to pretrained token inpainter weights (optional)
        TOKEN_INPAINTER_LR: 1.0e-5  # Learning rate for token inpainter (can differ from base LR)
        DEPTH: 6  # Number of transformer blocks
        HEADS: 16  # Number of attention heads
        DROP: 0  # Dropout rate
        USE_POSITIONAL_ENCODING: true  # Enable 2D sinusoidal positional encodings
        USE_FINAL_NORM: true  # Enable final LayerNorm before output projection
        USE_LOCAL_PRIOR: true  # Blend local mean prior for masked seeds
        # Weight for local prior blending (1.0 = only mask_token, 0.0 = only local mean)
        LOCAL_PRIOR_WEIGHT: 0.5
        LOCAL_PRIOR_KERNEL: 5  # Kernel size for local prior blending (> 1)
        SEED_NOISE_STD: 0.02  # Standard deviation of noise added to masked seeds during training
  INPAINT_MASK_DILATION:
    value: 3  # Dilation kernel size (pixels) for inpaint mask; must be odd
  USE_TORCH_COMPILE:
    # Enable PyTorch 2.0 torch.compile for faster training (experimental)
    value: false
  DISTRIBUTE:
    value: "ddp"

  ### DATA
  DATASETS:
    value:
      SHIQ:
        VAL_SCENES: ["test"]
        RESIZE_MODE: "resize+crop"
        TARGET_SIZE: [896, 896]
        SAMPLE_EVERY_N: 2
      SCARED:
        # Validation scene names
        VAL_SCENES: ["v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "v32", "v33", "v34"]
        RESIZE_MODE: "resize+crop"  # Image resizing mode
        SAMPLE_EVERY_N: 4  # Load every Nth frame
      ALL_DATASETS:
        FEW_IMAGES: false
        TARGET_SIZE: [896, 896]
        # NOTE(review): LOAD_RGB_ONLY and LOAD_HIGHLIGHT are both enabled - confirm that the
        # dataset loader does not treat them as mutually exclusive.
        LOAD_RGB_ONLY: true
        LOAD_HIGHLIGHT: true
  BATCH_SIZE:
    # Max batch size with img size 896 is 32
    value: 20  # Number of samples per batch (adjust based on GPU memory)
  NUM_WORKERS:
    # Number of data loading worker processes (0 = main process only, "auto" = 90% of CPU affinity)
    value: 12
  SHUFFLE:
    value: true  # Shuffle training data each epoch (False for validation/test)
  PIN_MEMORY:
    value: true  # Pin memory in DataLoader for faster GPU transfer (recommended: True)
  PREFETCH_FACTOR:
    value: 2  # Number of batches to prefetch per worker (higher = more memory usage)

  ### HIGHLIGHTS
  MOGE_MODEL:
    value: "Ruicheng/moge-2-vits-normal"  # MoGe model name for normal estimation (HuggingFace format)
  SURFACE_ROUGHNESS:
    value: 100.0  # Blinn-Phong surface roughness exponent (higher = sharper highlights)
  INTENSITY:
    value: 0.8  # Specular highlight intensity multiplier
  LIGHT_DISTANCE_RANGE:
    value: [0.0, 1]  # Range for light source distance sampling [min, max] (normalized)
  LIGHT_LEFT_RIGHT_ANGLE:
    value: [0, 360]  # Range for light source horizontal angle [min, max] in degrees
  LIGHT_ABOVE_BELOW_ANGLE:
    value: [0, 360]  # Range for light source vertical angle [min, max] in degrees
  DATASET_HIGHLIGHT_DILATION:
    value: 25  # Dilation kernel size (pixels) for dataset highlight masks
  DATASET_HIGHLIGHT_THRESHOLD:
    value: 0.9  # Brightness/luminance threshold (0-1) for detecting highlights in dataset images
  DATASET_HIGHLIGHT_USE_LUMINANCE:
    # If True, use perceptually-weighted luminance (0.299*R + 0.587*G + 0.114*B) for dataset
    # highlights; if False, use simple mean brightness
    value: true
  HIGHLIGHT_COLOR:
    value: [1.0, 1.0, 1.0]  # RGB color for synthetic highlights (normalized 0-1)
  CLAMP_RECONSTRUCTION:
    value: true  # Clamp reconstructed images to [0, 1] range if True

  ### OPTIMIZATION
  EPOCHS:
    value: 20  # Maximum number of training epochs
  LEARNING_RATE:
    value: 1.0e-4  # Base learning rate for optimizer
  WEIGHT_DECAY:
    value: 0.0  # L2 regularization weight (0.0 = no weight decay)
  GRADIENT_ACCUMULATION_STEPS:
    # Number of steps to accumulate gradients before optimizer step (1 = no accumulation)
    value: 1
  WARMUP:
    # Number of warmup steps for learning rate schedule (linear warmup from 0 to LR)
    value: 100
  GRADIENT_CLIPPING_MAX_NORM:
    value: 8  # Maximum gradient norm for clipping (set to -1 to disable clipping)
  LR_SCHEDULER:
    # NOTE(review): ONPLATEAU and COSINE are both enabled - confirm that the trainer is meant
    # to combine the two schedulers rather than pick one.
    value:
      ONPLATEAU:  # ReduceLROnPlateau scheduler (reduces LR when validation metric plateaus)
        PATIENCE: 5  # Number of epochs to wait before reducing LR
        FACTOR: 0.1  # Factor by which LR is reduced (new_lr = old_lr * factor)
      COSINE:  # CosineAnnealingLR scheduler (cosine annealing schedule)
        N_PERIODS: 1  # Number of cosine periods over training
      # STEPWISE:  # StepLR scheduler (reduces LR at fixed step intervals)
      #   N_STEPS: 5  # Number of times to reduce LR during training
      #   GAMMA: 0.25  # Factor by which LR is reduced at each step (new_lr = old_lr * gamma)
      # EXPONENTIAL:  # ExponentialLR scheduler (exponential decay)
      #   GAMMA: 0.5  # Multiplicative factor for exponential decay
  SWITCH_OPTIMIZER_EPOCH:
    # Epoch number to switch from bootstrap to refining optimizer (null = no switch)
    value: null
  OPTIMIZER_BOOTSTRAP_NAME:
    value: "AdamW"  # Optimizer name for initial training phase ("Adam", "SGD", etc.)
  OPTIMIZER_REFINING_NAME:
    value: "AdamW"  # Optimizer name for refining phase (used after SWITCH_OPTIMIZER_EPOCH)
  EARLY_STOPPING_PATIENCE:
    # Number of epochs without improvement before stopping training
    # NOTE(review): this equals EPOCHS (20), so early stopping cannot trigger within this run.
    value: 20
  SAVE_INTERVAL:
    value: 1000  # Number of training steps between model checkpoints
  DATASET_HIGHLIGHT_SUPERVISION_THRESHOLD:
    # Pixel highlights above this threshold (should be low) are excluded from supervision
    value: 0.1

  ### LOSS WEIGHTS (relative to the total loss, NOT NORMALIZED LATER)
  SPECULAR_LOSS_WEIGHT:
    value: 0.0  # Weight for specular component reconstruction loss
  DIFFUSE_LOSS_WEIGHT:
    value: 0.0  # Weight for diffuse component reconstruction loss
  HIGHLIGHT_LOSS_WEIGHT:
    value: 1.0  # Weight for highlight mask regression loss
  TOKEN_INPAINT_LOSS_WEIGHT:
    # Weight for token-space inpainting loss (L1 + cosine similarity in feature space)
    value: 0.0
  IMAGE_RECONSTRUCTION_LOSS_WEIGHT:
    value: 0.0  # Weight for full image reconstruction loss
  SATURATION_RING_LOSS_WEIGHT:
    value: 0.0  # Weight for saturation ring consistency loss (around highlight regions)
  RING_KERNEL_SIZE:
    value: 11  # Kernel size (odd number) for saturation ring dilation around highlights
  RING_VAR_WEIGHT:
    value: 0.5  # Weight for variance matching in saturation ring loss (vs mean matching)
  RING_TEXTURE_WEIGHT:
    value: 0.0  # Weight for texture consistency term in saturation ring loss
  HLREG_W_L1:
    value: 1.0  # Weight for L1 loss in highlight regression
  HLREG_USE_CHARB:
    value: true  # Use Charbonnier loss (smooth L1) instead of standard L1 if True
  HLREG_W_DICE:
    value: 0.2  # Weight for Dice loss in highlight regression (for mask overlap)
  HLREG_W_SSIM:
    value: 0.0  # Weight for SSIM loss in highlight regression
  HLREG_W_GRAD:
    value: 0.0  # Weight for gradient loss in highlight regression
  HLREG_W_TV:
    value: 0.0  # Weight for total variation loss in highlight regression
  HLREG_BALANCE_MODE:
    # Class balancing mode for highlight regression: 'none' | 'auto' | 'pos_weight'
    value: "auto"
  HLREG_POS_WEIGHT:
    value: 1.0  # Positive class weight (used only if BALANCE_MODE == 'pos_weight')
  HLREG_FOCAL_GAMMA:
    # Focal loss gamma parameter (0.0 = standard BCE, 1.0-2.0 helps with gradient vanishing)
    value: 2.0
  WEIGHT_CONTEXT_IDENTITY:
    # Leave at 0.0: weight for L1 loss on context (non-masked) regions (identity preservation)
    value: 0.0
  WEIGHT_TV_IN_HOLE:
    # Leave at 0.0: weight for total variation loss inside masked/hole regions
    value: 0.0
  RING_DILATE_KERNEL:
    value: 17  # Dilation kernel size (odd number) for creating ring mask around highlights
  WEIGHT_SEAM:
    value: 0.0  # Weight for gradient matching loss on saturation ring
  SEAM_USE_CHARB:
    # Use Charbonnier loss instead of L1 in seam loss (smooth L1 for boundary consistency)
    value: true
  SEAM_WEIGHT_GRAD:
    # Weight for gradient matching term inside seam loss (0.0 = disable gradient term)
    value: 0.0
  TOKEN_FEAT_ALPHA:
    # Mixing factor for token feature loss: alpha * L1 + (1-alpha) * (1-cosine_sim)
    value: 0.5

  ### DIFFUSE HIGHLIGHT PENALTY
  WEIGHT_DIFFUSE_HIGHLIGHT_PENALTY:
    # Weight for penalty loss on highlights in diffuse decoder output (0.0 = disabled)
    value: 0.0
  DIFFUSE_HL_THRESHOLD:
    value: 0.8  # Brightness/luminance threshold for detecting highlights in diffuse (0.0-1.0)
  DIFFUSE_HL_USE_CHARB:
    value: true  # Use Charbonnier loss instead of L1 for diffuse highlight penalty
  DIFFUSE_HL_PENALTY_MODE:
    # Penalty mode: "brightness" (penalize brightness/luminance above threshold)
    # or "pixel" (penalize RGB values directly)
    value: "brightness"
  DIFFUSE_HL_TARGET_BRIGHTNESS:
    # Target brightness/luminance for penalized pixels (null = use threshold value)
    value: null
  DIFFUSE_HL_USE_LUMINANCE:
    # If True, use perceptually-weighted luminance (0.299*R + 0.587*G + 0.114*B);
    # if False, use simple mean brightness
    value: true

  ### LOGGING, RESULTS AND WANDB
  LOG_INTERVAL:
    value: 1  # Number of training steps between console log outputs
  WANDB_LOG_INTERVAL:
    value: 1  # Number of training steps between WandB metric logs
  IMAGE_LOG_INTERVAL:
    value: 5  # Number of training steps between image logging to WandB
  NO_WANDB:
    value: false  # Disable WandB logging if True (useful for local debugging)
  MODEL_WATCHER_FREQ_WANDB:
    value: 50  # Frequency (in steps) for logging model parameter histograms to WandB
  WANDB_ENTITY:
    value: "unreflect-anything"  # WandB organization/entity name
  WANDB_PROJECT:
    value: "UnReflectAnything"  # WandB project name
  NOTES:
    value: ""  # Notes/description for this training run