| import argparse
|
|
|
|
|
def _add_experiment_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'Experiment & Logging' options on *parser*."""
    exp = parser.add_argument_group('Experiment & Logging')
    exp.add_argument('--out-dir', type=str, default='./output',
                     help='Root directory to save logs, checkpoints, and outputs')
    exp.add_argument('--exp-name', type=str, default='IAM_HTR_ORIGAMI_NET',
                     help='Experiment name; results go to <out-dir>/<exp-name>')
    exp.add_argument('--seed', default=123, type=int,
                     help='Random seed for reproducibility')
    exp.add_argument('--use-wandb', action='store_true', default=False,
                     help='Log to Weights & Biases; otherwise use TensorBoard')
    # NOTE: the default is the literal string 'None', not the None object;
    # it is kept as-is so existing consumers see the same value.
    exp.add_argument('--wandb-project', type=str, default='None',
                     help='W&B project name (used only if --use-wandb)')
    exp.add_argument('--print-iter', default=100, type=int,
                     help='Iterations between training status prints')
    exp.add_argument('--eval-iter', default=1000, type=int,
                     help='Iterations between validation runs')


def _add_data_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'Data & Dataloading' options on *parser*."""
    data = parser.add_argument_group('Data & Dataloading')
    # No default is given, so args.dataset is None unless passed explicitly.
    data.add_argument('--dataset', type=str, choices=['iam', 'read2016', 'lam', 'vnondb'],
                      help='Dataset choice')
    data.add_argument('--data-path', type=str, default='./data/iam/lines/',
                      help='Root directory containing image/line data')
    data.add_argument('--train-data-list', type=str, default='./data/iam/train.ln',
                      help='Path to training list file (e.g., .ln)')
    data.add_argument('--val-data-list', type=str, default='./data/iam/val.ln',
                      help='Path to validation list file (e.g., .ln)')
    data.add_argument('--test-data-list', type=str, default='./data/iam/test.ln',
                      help='Path to test list file (e.g., .ln)')
    data.add_argument('--nb-cls', default=80, type=int,
                      help='Number of classes. IAM=79+1, READ2016=89+1, LAM=90+1, VNOnDB=161+1')
    data.add_argument('--num-workers', default=0, type=int,
                      help='Dataloader worker processes')
    data.add_argument('--img-size', default=[512, 64], type=int, nargs='+',
                      help='Input image size [W, H]')
    data.add_argument('--patch-size', default=[4, 32], type=int, nargs='+',
                      help='Patch size [W, H] for patch embedding')


def _add_training_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'Training Schedule & Optimization' options on *parser*."""
    train = parser.add_argument_group('Training Schedule & Optimization')
    train.add_argument('--train-bs', default=8, type=int,
                       help='Training batch size per iteration')
    train.add_argument('--accum-steps', default=1, type=int,
                       help='Gradient accumulation steps; effective batch = train-bs * accum-steps')
    train.add_argument('--val-bs', default=1, type=int,
                       help='Validation/test batch size')
    train.add_argument('--total-iter', default=100000, type=int,
                       help='Total training iterations')
    train.add_argument('--warm-up-iter', default=1000, type=int,
                       help='Warm-up iterations for the optimizer/scheduler')
    train.add_argument('--max-lr', default=1e-3, type=float,
                       help='Peak learning rate')
    train.add_argument('--weight-decay', default=5e-1, type=float,
                       help='Weight decay (L2) regularization')
    train.add_argument('--ema-decay', default=0.9999, type=float,
                       help='Exponential Moving Average (EMA) decay factor for model weights')
    train.add_argument('--alpha', default=0, type=float,
                       help='KL-divergence loss ratio (if applicable)')


def _add_model_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'Model & Encoder' options on *parser*."""
    model = parser.add_argument_group('Model & Encoder')
    model.add_argument('--model-type', default='ctc', type=str, choices=['ctc', 'encoder_decoder'],
                       help='Model family to train/use')
    model.add_argument('--cos-temp', default=8, type=int,
                       help='Cosine-similarity classifier temperature')
    model.add_argument('--proj', default=8, type=float,
                       help='Projection dimension or scaling for classifier head')
    model.add_argument('--attn-mask-ratio', default=0., type=float,
                       help='Attention drop-key mask ratio')


def _add_masking_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'Masking Strategy' options on *parser*."""
    mask = parser.add_argument_group('Masking Strategy')
    mask.add_argument('--use-masking', action='store_true', default=False,
                      help='Enable masking strategy during training')
    mask.add_argument('--mask-ratio', default=0.3, type=float,
                      help='Overall proportion of tokens/patches to mask')
    mask.add_argument('--max-span-length', default=4, type=int,
                      help='Max length for individual span masks')
    mask.add_argument('--spacing', default=0, type=int,
                      help='Minimum spacing between two span masks')

    mask.add_argument('--r-rand', dest='r_rand', default=0.6, type=float,
                      help='Ratio for random masking in tri-masking schedule')
    mask.add_argument('--r-block', dest='r_block', default=0.6, type=float,
                      help='Ratio for block masking in tri-masking schedule')
    mask.add_argument('--block-span', dest='block_span', default=4, type=int,
                      help='Block span length for block masking')
    mask.add_argument('--r-span', dest='r_span', default=0.4, type=float,
                      help='Ratio for span masking in tri-masking schedule')
    mask.add_argument('--max-span', dest='max_span', default=8, type=int,
                      help='Max span length for span masking')


def _add_augmentation_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'Data Augmentations' options on *parser*."""
    aug = parser.add_argument_group('Data Augmentations')
    aug.add_argument('--dpi-min-factor', default=0.5, type=float,
                     help='Minimum scaling factor for DPI-based resize')
    aug.add_argument('--dpi-max-factor', default=1.5, type=float,
                     help='Maximum scaling factor for DPI-based resize')
    aug.add_argument('--perspective-low', default=0., type=float,
                     help='Lower bound for perspective transform magnitude')
    aug.add_argument('--perspective-high', default=0.4, type=float,
                     help='Upper bound for perspective transform magnitude')
    aug.add_argument('--elastic-distortion-min-kernel-size', default=3, type=int,
                     help='Minimum kernel size for elastic distortion grid')
    aug.add_argument('--elastic-distortion-max-kernel-size', default=3, type=int,
                     help='Maximum kernel size for elastic distortion grid')
    # The consistently hyphenated spelling comes first; the legacy mixed
    # underscore/hyphen spelling stays as an alias so existing command lines
    # keep working. Both map to dest 'elastic_distortion_max_magnitude'.
    aug.add_argument('--elastic-distortion-max-magnitude',
                     '--elastic_distortion-max-magnitude', default=20, type=int,
                     help='Maximum distortion magnitude for elastic transforms')
    aug.add_argument('--elastic-distortion-min-alpha', default=0.5, type=float,
                     help='Minimum alpha for elastic distortion')
    aug.add_argument('--elastic-distortion-max-alpha', default=1, type=float,
                     help='Maximum alpha for elastic distortion')
    aug.add_argument('--elastic-distortion-min-sigma', default=1, type=int,
                     help='Minimum sigma for Gaussian in elastic distortion')
    aug.add_argument('--elastic-distortion-max-sigma', default=10, type=int,
                     help='Maximum sigma for Gaussian in elastic distortion')
    aug.add_argument('--dila-ero-max-kernel', default=3, type=int,
                     help='Max kernel size for dilation/erosion ops')
    aug.add_argument('--dila-ero-iter', default=1, type=int,
                     help='Iterations for dilation/erosion')
    aug.add_argument('--jitter-contrast', default=0.4, type=float,
                     help='ColorJitter: contrast range')
    aug.add_argument('--jitter-brightness', default=0.4, type=float,
                     help='ColorJitter: brightness range')
    aug.add_argument('--jitter-saturation', default=0.4, type=float,
                     help='ColorJitter: saturation range')
    aug.add_argument('--jitter-hue', default=0.2, type=float,
                     help='ColorJitter: hue range')
    aug.add_argument('--blur-min-kernel', default=3, type=int,
                     help='Minimum kernel size for Gaussian blur')
    aug.add_argument('--blur-max-kernel', default=5, type=int,
                     help='Maximum kernel size for Gaussian blur')
    aug.add_argument('--blur-min-sigma', default=3, type=int,
                     help='Minimum sigma for Gaussian blur')
    aug.add_argument('--blur-max-sigma', default=5, type=int,
                     help='Maximum sigma for Gaussian blur')
    aug.add_argument('--sharpen-min-alpha', default=0, type=int,
                     help='Minimum alpha/mix for sharpening')
    aug.add_argument('--sharpen-max-alpha', default=1, type=int,
                     help='Maximum alpha/mix for sharpening')
    aug.add_argument('--sharpen-min-strength', default=0, type=int,
                     help='Minimum sharpening strength')
    aug.add_argument('--sharpen-max-strength', default=1, type=int,
                     help='Maximum sharpening strength')
    aug.add_argument('--zoom-min-h', default=0.8, type=float,
                     help='Minimum vertical zoom factor')
    aug.add_argument('--zoom-max-h', default=1, type=float,
                     help='Maximum vertical zoom factor')
    aug.add_argument('--zoom-min-w', default=0.99, type=float,
                     help='Minimum horizontal zoom factor')
    aug.add_argument('--zoom-max-w', default=1, type=float,
                     help='Maximum horizontal zoom factor')
    aug.add_argument('--proba', default=0.5, type=float,
                     help='Default probability for applying stochastic augmentations')


def _add_decoder_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'Decoder & Inference' options on *parser*."""
    dec = parser.add_argument_group('Decoder & Inference')
    dec.add_argument('--decoder-layers', default=6, type=int,
                     help='Number of Transformer decoder layers')
    dec.add_argument('--decoder-heads', default=8, type=int,
                     help='Number of attention heads in decoder')
    dec.add_argument('--max-seq-len', default=256, type=int,
                     help='Maximum output sequence length')
    dec.add_argument('--label-smoothing', default=0.1, type=float,
                     help='Label-smoothing factor for cross-entropy loss')
    dec.add_argument('--beam-size', default=5, type=int,
                     help='Beam size for beam-search decoding')
    dec.add_argument('--generation-method', default='nucleus', type=str,
                     choices=['greedy', 'nucleus', 'beam_search'],
                     help='Token generation method for inference')
    dec.add_argument('--generation-temperature', default=0.7, type=float,
                     help='Sampling temperature (used by nucleus/greedy sampling)')
    dec.add_argument('--repetition-penalty', default=1.3, type=float,
                     help='Penalty to discourage token repetition during generation')
    dec.add_argument('--top-p', default=0.9, type=float,
                     help='Top-p threshold for nucleus sampling')


def _add_tcm_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'TCM (Textual Context Module)' options on *parser*."""
    tcm = parser.add_argument_group('TCM (Textual Context Module)')
    tcm.add_argument('--tcm-enable', action='store_true', default=False,
                     help='Enable Textual Context Module (TCM)')
    tcm.add_argument('--tcm-lambda', default=1.0, type=float,
                     help='TCM loss weight (λ2 in the paper)')
    tcm.add_argument('--ctc-lambda', default=0.1, type=float,
                     help='CTC loss weight (λ1 in the paper)')
    tcm.add_argument('--tcm-sub-len', default=5, type=int,
                     help='TCM context sub-string length')
    tcm.add_argument('--tcm-warmup-iters', default=0, type=int,
                     help='Warm-up iterations before activating TCM (0 = start immediately)')


def _add_checkpoint_args(parser: argparse.ArgumentParser) -> None:
    """Register the 'Checkpointing & Pretrained Weights' options on *parser*."""
    ckpt = parser.add_argument_group('Checkpointing & Pretrained Weights')
    ckpt.add_argument('--resume', type=str, default=None,
                      help='Resume training from a checkpoint (alias)')
    ckpt.add_argument('--load-model', type=str, default=None,
                      help='Load a full pretrained model for fine-tuning')
    ckpt.add_argument('--load-encoder-only', action='store_true', default=False,
                      help='Load only encoder weights (transfer learning)')
    # '--strict-loading' is store_true with default=True, so on its own it can
    # never yield False. It is kept for compatibility; the store_false
    # companion below writes to the same dest and is the only way to turn
    # strict loading off from the command line.
    ckpt.add_argument('--strict-loading', action='store_true', default=True,
                      help='Use strict key matching when loading weights')
    ckpt.add_argument('--no-strict-loading', dest='strict_loading', action='store_false',
                      help='Disable strict key matching when loading weights')


def _build_parser() -> argparse.ArgumentParser:
    """Create the HTR-ConvText parser with every option group registered."""
    parser = argparse.ArgumentParser(
        description='HTR-ConvText: Leveraging Convolution and Textual Context with Mixed Masking for Handwritten Text Recognition',
        add_help=True,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Registration order determines the order of groups in --help output.
    _add_experiment_args(parser)
    _add_data_args(parser)
    _add_training_args(parser)
    _add_model_args(parser)
    _add_masking_args(parser)
    _add_augmentation_args(parser)
    _add_decoder_args(parser)
    _add_tcm_args(parser)
    _add_checkpoint_args(parser)
    return parser


def get_args_parser(argv=None) -> argparse.Namespace:
    """Create the HTR-ConvText option parser and parse the command line.

    All pre-existing option names, destinations, types, and defaults are
    unchanged. Two backward-compatible additions are made:

    * ``--no-strict-loading`` sets ``strict_loading`` to ``False``.
    * ``--elastic-distortion-max-magnitude`` is accepted alongside the
      legacy ``--elastic_distortion-max-magnitude`` spelling.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which matches
            calling this function with no arguments.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    return _build_parser().parse_args(argv)