# tite-2-late-upscale-msmarco / pl_config.yaml
# Uploaded by fschlatt using huggingface_hub (commit 86f4284, verified).
# lightning.pytorch==2.5.2
# Random seed for the run (the CLI's `seed_everything` option); fixed at 0.
seed_everything: 0
# Arguments for the Lightning Trainer (file header: lightning.pytorch 2.5.2).
# Nesting restored: every key down to `model_registry` belongs to `trainer`.
# A `class_path` / `init_args` pair names a class and its constructor
# arguments.
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: bf16-mixed
  # Weights & Biases logger; runs go to entity `tite`, project `tite`.
  logger:
    class_path: lightning_ir.LightningIRWandbLogger
    init_args:
      name: null
      save_dir: .
      version: null
      offline: false
      dir: null
      id: null
      anonymous: null
      project: tite
      log_model: false
      experiment: null
      prefix: ''
      checkpoint_name: null
      entity: tite
      notes: null
      tags: null
      config: null
      config_exclude_keys: null
      config_include_keys: null
      allow_val_change: null
      group: null
      job_type: null
      mode: null
      force: null
      reinit: null
      resume: null
      resume_from: null
      fork_from: null
      save_code: null
      tensorboard: null
      sync_tensorboard: null
      monitor_gym: null
      settings: null
  callbacks:
  # NOTE(review): `monitor` is null, so `mode: min` has no metric to rank
  # checkpoints by; confirm the intended checkpoint policy.
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      dirpath: null
      filename: null
      monitor: null
      verbose: false
      save_last: null
      save_top_k: 1
      save_weights_only: false
      mode: min
      auto_insert_metric_name: true
      every_n_train_steps: null
      train_time_interval: null
      every_n_epochs: null
      save_on_train_epoch_end: null
      enable_version_counter: true
  # Project-specific callback with no init_args; purpose is not visible
  # here (presumably forces an import of the `tite` package - confirm).
  - class_path: tite.utils.callbacks.DummyImportCallback
  fast_dev_run: false
  # Training length is bounded by max_steps only; no epoch or time limits.
  max_epochs: null
  min_epochs: null
  max_steps: 10100
  min_steps: null
  max_time: null
  limit_train_batches: null
  limit_val_batches: null
  limit_test_batches: null
  limit_predict_batches: null
  overfit_batches: 0.0
  # NOTE(review): 20000 exceeds max_steps (10100), and the `data` section
  # of this file lists only a training dataset, so validation presumably
  # never runs - confirm this is intended.
  val_check_interval: 20000
  check_val_every_n_epoch: null
  num_sanity_val_steps: null
  log_every_n_steps: null
  enable_checkpointing: null
  enable_progress_bar: false
  enable_model_summary: null
  # Gradients are accumulated over 8 batches and clipped at 1; the clip
  # algorithm is left unset (null).
  accumulate_grad_batches: 8
  gradient_clip_val: 1
  gradient_clip_algorithm: null
  deterministic: null
  benchmark: null
  inference_mode: true
  use_distributed_sampler: true
  profiler: null
  detect_anomaly: false
  barebones: false
  plugins: null
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0
  default_root_dir: null
  model_registry: null
# Model section: a lightning_ir bi-encoder module that references
# `webis/tite-2-late-upscale` as its model name or path.
# Nesting restored: `config` and `loss_functions` are arguments of the
# module, each with its own nested `class_path` / `init_args`.
model:
  class_path: lightning_ir.BiEncoderModule
  init_args:
    model_name_or_path: webis/tite-2-late-upscale
    # Encoder config: dot similarity without normalization, `first` pooling
    # for queries and documents, embedding_dim 768, no projection.
    # query_length / doc_length are presumably token limits - confirm.
    config:
      class_path: lightning_ir.models.DprConfig
      init_args:
        query_length: 32
        doc_length: 256
        similarity_function: dot
        normalize: false
        sparsification: null
        add_marker_tokens: true
        query_pooling_strategy: first
        doc_pooling_strategy: first
        embedding_dim: 768
        projection: null
    model: null
    # Three losses are listed; this file sets no weights for them. Only the
    # last one takes explicit arguments.
    loss_functions:
    - class_path: lightning_ir.SupervisedMarginMSE
    - class_path: lightning_ir.KLDivergence
    - class_path: lightning_ir.ScoreBasedInBatchCrossEntropy
      init_args:
        min_target_diff: 3.0
        max_num_neg_samples: null
    evaluation_metrics:
    - nDCG@10
    index_dir: null
    search_config: null
    model_kwargs: null
# Data section: a lightning_ir data module with a single training dataset.
# Nesting restored: the RunDataset arguments sit under
# `train_dataset.init_args`; batch size, shuffling and worker count are
# arguments of the data module itself.
# No validation or inference dataset is configured in this section.
data:
  class_path: lightning_ir.LightningIRDataModule
  init_args:
    # Run `msmarco-passage/train/rank-distillm-set-encoder`, targets taken
    # from `score` without normalization. depth 100 / sample_size 16 with
    # `log_random` sampling presumably means 16 documents drawn from the
    # top 100 per query - confirm against the RunDataset documentation.
    train_dataset:
      class_path: lightning_ir.RunDataset
      init_args:
        run_path_or_id: msmarco-passage/train/rank-distillm-set-encoder
        depth: 100
        sample_size: 16
        sampling_strategy: log_random
        targets: score
        normalize_targets: false
        add_docs_not_in_ranking: false
    train_batch_size: 64
    shuffle_train: true
    inference_batch_size: 1
    num_workers: 4
# Learning-rate scheduler from the project's `tite.utils.lr_schedulers`.
# Going by the class name: linear warmup, then a constant rate; the exact
# behaviour is defined in that module - confirm there.
# Nesting restored: the step counts are `init_args` of the scheduler.
lr_scheduler:
  class_path: tite.utils.lr_schedulers.ConstantLRSchedulerWithLinearWarmup
  init_args:
    num_warmup_steps: 3000
    num_delay_steps: 0
# Optimizer: torch.optim.AdamW with lr 5e-5 and weight decay 0.01.
# Nesting restored: all hyperparameters are `init_args` of the optimizer,
# and `betas` is a two-element sequence.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    betas:
    - 0.9
    - 0.999
    eps: 1.0e-08
    weight_decay: 0.01
    amsgrad: false
    maximize: false
    foreach: null
    capturable: false
    differentiable: false
    fused: null
# No checkpoint path is set (null); presumably the run does not resume from
# an earlier checkpoint - confirm against the CLI subcommand in use.
ckpt_path: null