Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,237 Bytes
4724018 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, List
@dataclass
class TrainingConfig:
    """Settings for a training run, grouped by concern.

    Every field is required (none has a default), so whatever builds this
    object must supply all of them. Because this is a plain dataclass, the
    field order below is also the order of the generated ``__init__``'s
    positional parameters -- append new fields rather than inserting them.
    """

    image_size: int

    # Logging
    output_dir: str
    logging_dir: str
    vis_dir: str
    report_to: Optional[str]
    local_rank: int
    tracker_project_name: str

    # Training
    seed: Optional[int]
    train_batch_size: int
    # NOTE(review): an earlier revision described the eval batch size as
    # "how many images to sample during evaluation" -- confirm for this model.
    eval_batch_size: int
    num_train_epochs: int
    max_train_steps: int
    gradient_accumulation_steps: int
    gradient_checkpointing: bool
    learning_rate: float
    scale_lr: bool
    lr_scheduler: str
    lr_warmup_steps: int
    use_8bit_adam: bool
    allow_tf32: bool
    dataloader_num_workers: int
    adam_beta1: float
    adam_beta2: float
    adam_weight_decay: float
    adam_epsilon: float
    max_grad_norm: Optional[float]
    prediction_type: Optional[str]
    # NOTE(review): an earlier revision documented this as "`no` for float32,
    # `fp16` for automatic mixed precision" -- confirm the accepted values
    # against the trainer that consumes this field.
    mixed_precision: Optional[str]
    checkpointing_steps: int
    checkpoints_total_limit: Optional[int]
    resume_from_checkpoint: Optional[str]
    enable_xformers_memory_efficient_attention: bool
    validation_steps: int
    validation_train_steps: int
    validation_sanity_check: bool
    resume_step: Optional[int]
    push_to_hub: bool
    set_grads_to_none: bool
    # presumably weights of individual loss terms -- TODO confirm in the loss code
    lambda_vel: float
    lambda_mask: float
    lambda_momentum: float
    lambda_deform: float
    overfit: bool

    # Diffusion specific
    # presumably the probability of dropping the conditioning input -- confirm
    condition_drop_rate: float

    # Dataset
    train_dataset: Dict

    # Model
    model_type: str
    pred_offset: bool
    model_config: Dict
    pc_size: int
@dataclass()
class TestingConfig:
    """Settings for an evaluation / inference run.

    No field has a default, so every value must be provided when the object
    is built; positional construction follows the declaration order below.
    """

    dataloader_num_workers: int

    # Model and dataset description (these names also appear on TrainingConfig)
    pc_size: int
    model_type: str
    pred_offset: bool
    model_config: Dict
    train_dataset: Dict

    # Run control
    resume: str  # presumably a checkpoint to load -- confirm with the caller
    vis_dir: str
    eval_batch_size: int
    seed: int
    num_inference_steps: int