danielle-miller-sayag committed on
Commit
f55a812
·
verified ·
1 Parent(s): 7b4d0f9

initial: weights + modeling code + lean config

Browse files
train.py CHANGED
@@ -4,9 +4,12 @@ import sys
4
  from dataclasses import dataclass
5
  from typing import Dict, List, Optional
6
 
 
7
  import torch
8
- from datasets import load_dataset
 
9
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 
10
 
11
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
12
  from modeling_virtual_cell import VirtualCellPatientConfig, VirtualCellPatientModel
@@ -27,11 +30,60 @@ class PatientCollator:
27
  }
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
class PatientTrainer(Trainer):
    """Trainer that takes the training loss directly from the model output."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on the batch and return the ``loss`` it reports.

        Returns ``(loss, outputs)`` when ``return_outputs`` is true, otherwise
        only the loss. Extra keyword arguments are accepted and ignored.
        """
        model_outputs = model(**inputs)
        if return_outputs:
            return model_outputs.loss, model_outputs
        return model_outputs.loss
34
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def parse_args():
37
  p = argparse.ArgumentParser()
@@ -53,6 +105,7 @@ def parse_args():
53
  p.add_argument("--lr_scheduler_type", default="cosine")
54
  p.add_argument("--patience", type=int, default=5)
55
  p.add_argument("--num_workers", type=int, default=4)
 
56
  p.add_argument("--wandb_project", default=None)
57
  p.add_argument("--run_name", default=None)
58
 
@@ -62,7 +115,10 @@ def parse_args():
62
  def main():
63
  args = parse_args()
64
 
65
- ds = load_dataset(args.dataset_path)
 
 
 
66
  train_ds = ds["train"]
67
  val_ds: Optional[object] = ds.get("validation")
68
 
@@ -108,6 +164,9 @@ def main():
108
  report_to="wandb" if args.wandb_project else "none",
109
  run_name=args.run_name,
110
  dataloader_num_workers=args.num_workers,
 
 
 
111
  remove_unused_columns=False,
112
  )
113
 
@@ -119,6 +178,7 @@ def main():
119
  train_dataset=train_ds,
120
  eval_dataset=val_ds,
121
  data_collator=PatientCollator(),
 
122
  callbacks=callbacks,
123
  )
124
 
 
4
  from dataclasses import dataclass
5
  from typing import Dict, List, Optional
6
 
7
+ import numpy as np
8
  import torch
9
+ from datasets import DatasetDict, load_dataset
10
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
11
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
12
+ from transformers.trainer_utils import EvalPrediction
13
 
14
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15
  from modeling_virtual_cell import VirtualCellPatientConfig, VirtualCellPatientModel
 
30
  }
31
 
32
 
33
+ def _patient_predictions(logits: np.ndarray, entity_ids: np.ndarray):
34
+ """Average softmax probabilities across augmented views, one row per patient."""
35
+ entity_ids = np.asarray(entity_ids).astype(str)
36
+ unique = np.unique(entity_ids)
37
+ agg = []
38
+ for eid in unique:
39
+ views = logits[entity_ids == eid]
40
+ exp = np.exp(views - np.max(views, axis=1, keepdims=True))
41
+ agg.append(np.mean(exp / exp.sum(axis=1, keepdims=True), axis=0))
42
+ return np.array(agg), unique
43
+
44
+
45
def _clf_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str) -> Dict[str, float]:
    """Return accuracy and macro-averaged F1, precision and recall.

    Every key of the returned dict is ``prefix`` followed by the metric name.
    The macro-averaged scores are computed with ``zero_division=0``.
    """
    macro = {"average": "macro", "zero_division": 0}
    scores = {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, **macro),
        "precision": precision_score(y_true, y_pred, **macro),
        "recall": recall_score(y_true, y_pred, **macro),
    }
    return {f"{prefix}{name}": value for name, value in scores.items()}
52
+
53
+
54
def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute per-view and per-patient classification metrics.

    ``eval_pred.predictions`` is read as an ``(N, num_classes + 1)`` array
    whose last column is the entity id appended by
    ``PatientTrainer.prediction_step``; the remaining columns are the logits.

    Returns:
        A dict with ``per_view/*`` metrics (every row scored on its own) and
        ``patient/*`` metrics (views of one entity averaged first, scored
        against the label of that entity's first row).
    """
    logits_with_entity = eval_pred.predictions  # (N, num_classes + 1)
    logits = logits_with_entity[:, :-1]
    # The id column travels as a float alongside the logits; cast it back.
    entity_ids = logits_with_entity[:, -1].astype(int)
    labels = eval_pred.label_ids

    metrics = _clf_metrics(labels, np.argmax(logits, axis=1), "per_view/")

    patient_preds, unique_entities = _patient_predictions(logits, entity_ids)

    # Index the first row of every entity in a single pass instead of
    # scanning all rows once per patient.
    first_row: Dict[int, int] = {}
    for row, eid in enumerate(entity_ids.tolist()):
        first_row.setdefault(eid, row)
    patient_labels = np.array([labels[first_row[int(eid)]] for eid in unique_entities])
    metrics.update(_clf_metrics(patient_labels, np.argmax(patient_preds, axis=1), "patient/"))

    return metrics
70
+
71
+
72
class PatientTrainer(Trainer):
    """Trainer that uses the model's own loss and carries entity ids through evaluation."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on the batch and return the ``loss`` it reports.

        Returns ``(loss, outputs)`` when ``return_outputs`` is true, otherwise
        only the loss. Extra keyword arguments are accepted and ignored.
        """
        model_outputs = model(**inputs)
        if return_outputs:
            return model_outputs.loss, model_outputs
        return model_outputs.loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run the parent prediction step and append the entity id to the logits.

        ``entity_id`` is removed from ``inputs`` before delegating. When the
        parent step returns logits, the ids are added as one extra trailing
        column so they reach ``compute_metrics`` together with the logits.
        """
        entity_id = inputs.pop("entity_id")
        step_result = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )
        loss, logits, labels = step_result
        if logits is None:
            return loss, logits, labels
        # NOTE(review): ids are stored as float32, which is only exact for
        # integers up to 2**24 -- confirm entity ids stay below that.
        entity_col = entity_id.float().unsqueeze(1).to(logits.device)
        return loss, torch.cat([logits, entity_col], dim=1), labels
86
+
87
 
88
  def parse_args():
89
  p = argparse.ArgumentParser()
 
105
  p.add_argument("--lr_scheduler_type", default="cosine")
106
  p.add_argument("--patience", type=int, default=5)
107
  p.add_argument("--num_workers", type=int, default=4)
108
+ p.add_argument("--prefetch_factor", type=int, default=2)
109
  p.add_argument("--wandb_project", default=None)
110
  p.add_argument("--run_name", default=None)
111
 
 
115
  def main():
116
  args = parse_args()
117
 
118
+ if os.path.isdir(args.dataset_path):
119
+ ds = DatasetDict.load_from_disk(args.dataset_path)
120
+ else:
121
+ ds = load_dataset(args.dataset_path, num_proc=args.num_workers)
122
  train_ds = ds["train"]
123
  val_ds: Optional[object] = ds.get("validation")
124
 
 
164
  report_to="wandb" if args.wandb_project else "none",
165
  run_name=args.run_name,
166
  dataloader_num_workers=args.num_workers,
167
+ dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
168
+ dataloader_persistent_workers=args.num_workers > 0,
169
+ dataloader_pin_memory=True,
170
  remove_unused_columns=False,
171
  )
172
 
 
178
  train_dataset=train_ds,
179
  eval_dataset=val_ds,
180
  data_collator=PatientCollator(),
181
+ compute_metrics=compute_metrics if has_val else None,
182
  callbacks=callbacks,
183
  )
184
 
wandb/debug-internal.log ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-05-03T17:12:14.512438+03:00","level":"INFO","msg":"stream: starting","core version":"0.21.0"}
2
+ {"time":"2026-05-03T17:12:15.049447+03:00","level":"INFO","msg":"stream: created new stream","id":"h9m78x54"}
3
+ {"time":"2026-05-03T17:12:15.049472+03:00","level":"INFO","msg":"stream: started","id":"h9m78x54"}
4
+ {"time":"2026-05-03T17:12:15.049488+03:00","level":"INFO","msg":"writer: Do: started","stream_id":"h9m78x54"}
5
+ {"time":"2026-05-03T17:12:15.049533+03:00","level":"INFO","msg":"sender: started","stream_id":"h9m78x54"}
6
+ {"time":"2026-05-03T17:12:15.049551+03:00","level":"INFO","msg":"handler: started","stream_id":"h9m78x54"}
7
+ {"time":"2026-05-03T17:12:15.531811+03:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
8
+ {"time":"2026-05-03T17:14:51.01985+03:00","level":"INFO","msg":"stream: closing","id":"h9m78x54"}
9
+ {"time":"2026-05-03T17:14:51.643326+03:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
10
+ {"time":"2026-05-03T17:14:51.997995+03:00","level":"INFO","msg":"sender: closed","stream_id":"h9m78x54"}
11
+ {"time":"2026-05-03T17:14:51.998011+03:00","level":"INFO","msg":"handler: closed","stream_id":"h9m78x54"}
12
+ {"time":"2026-05-03T17:14:51.998039+03:00","level":"INFO","msg":"writer: Close: closed","stream_id":"h9m78x54"}
13
+ {"time":"2026-05-03T17:14:51.998606+03:00","level":"INFO","msg":"stream: closed","id":"h9m78x54"}
wandb/debug.log ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-05-03 17:12:13,856 INFO MainThread:63423 [wandb_setup.py:_flush():80] Current SDK version is 0.21.0
2
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_setup.py:_flush():80] Configure stats pid to 63423
3
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_setup.py:_flush():80] Loading settings from /Users/daniellemillersayag/.config/wandb/settings
4
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_setup.py:_flush():80] Loading settings from /Users/daniellemillersayag/Documents/vcell/paper/hf-release/wandb/settings
5
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_init.py:setup_run_log_directory():703] Logging user logs to /Users/daniellemillersayag/Documents/vcell/paper/hf-release/wandb/run-20260503_171213-h9m78x54/logs/debug.log
7
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to /Users/daniellemillersayag/Documents/vcell/paper/hf-release/wandb/run-20260503_171213-h9m78x54/logs/debug-internal.log
8
+ 2026-05-03 17:12:13,858 INFO MainThread:63423 [wandb_init.py:init():830] calling init triggers
9
+ 2026-05-03 17:12:13,858 INFO MainThread:63423 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {'code_path': 'code/train.py'}}
11
+ 2026-05-03 17:12:13,858 INFO MainThread:63423 [wandb_init.py:init():871] starting backend
12
+ 2026-05-03 17:12:14,495 INFO MainThread:63423 [wandb_init.py:init():874] sending inform_init request
13
+ 2026-05-03 17:12:14,511 INFO MainThread:63423 [wandb_init.py:init():882] backend started and connected
14
+ 2026-05-03 17:12:14,513 INFO MainThread:63423 [wandb_init.py:init():953] updated telemetry
15
+ 2026-05-03 17:12:14,513 INFO MainThread:63423 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
16
+ 2026-05-03 17:12:15,529 INFO MainThread:63423 [wandb_init.py:init():1029] starting run threads in backend
17
+ 2026-05-03 17:12:15,651 INFO MainThread:63423 [wandb_run.py:_console_start():2458] atexit reg
18
+ 2026-05-03 17:12:15,651 INFO MainThread:63423 [wandb_run.py:_redirect():2306] redirect: wrap_raw
19
+ 2026-05-03 17:12:15,651 INFO MainThread:63423 [wandb_run.py:_redirect():2375] Wrapping output streams.
20
+ 2026-05-03 17:12:15,651 INFO MainThread:63423 [wandb_run.py:_redirect():2398] Redirects installed.
21
+ 2026-05-03 17:12:15,652 INFO MainThread:63423 [wandb_init.py:init():1075] run started, returning control to user process
22
+ 2026-05-03 17:12:15,653 INFO MainThread:63423 [wandb_run.py:_config_callback():1363] config_cb None None {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float32', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['VirtualCellPatientModel'], 'finetuning_task': None, 'id2label': {0: 'oncological', 1: 'immune_inflammatory', 2: 'neurological', 3: 'metabolic_vascular', 4: 'gastrointestinal', 5: 'respiratory', 6: 'epithelial_barrier', 7: 'sensory_specialized', 8: 'healthy_control', 9: 'other'}, 'label2id': {'oncological': 0, 'immune_inflammatory': 1, 'neurological': 2, 'metabolic_vascular': 3, 'gastrointestinal': 4, 'respiratory': 5, 'epithelial_barrier': 6, 'sensory_specialized': 7, 'healthy_control': 8, 'other': 9}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '/Users/daniellemillersayag/Documents/vcell/paper/hf-release', '_attn_implementation_autoset': True, 
'transformers_version': '4.51.3', 'model_type': 'virtual_cell_patient', 'auto_map': {'AutoConfig': 'modeling_virtual_cell.VirtualCellPatientConfig', 'AutoModel': 'modeling_virtual_cell.VirtualCellPatientModel'}, 'n_genes': 18301, 'embed_dim': 512, 'hidden_dim': [4096, 1024], 'dropout': 0.1, 'residual': False, 'activation': 'prelu', 'attention_hidden_dim': 512, 'num_classes': 10, 'classifier_dropout': 0.1, 'output_dir': '/tmp/vc_smoke_test_wandb', 'overwrite_output_dir': False, 'do_train': False, 'do_eval': True, 'do_predict': False, 'eval_strategy': 'epoch', 'prediction_loss_only': False, 'per_device_train_batch_size': 4, 'per_device_eval_batch_size': 4, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'eval_delay': 0, 'torch_empty_cache_steps': None, 'learning_rate': 0.0001, 'weight_decay': 0.05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 2, 'max_steps': -1, 'lr_scheduler_type': 'cosine', 'lr_scheduler_kwargs': {}, 'warmup_ratio': 0.1, 'warmup_steps': 0, 'log_level': 'passive', 'log_level_replica': 'warning', 'log_on_each_node': True, 'logging_dir': '/tmp/vc_smoke_test_wandb/runs/May03_17-12-12_Mac.lan', 'logging_strategy': 'steps', 'logging_first_step': False, 'logging_steps': 500, 'logging_nan_inf_filter': True, 'save_strategy': 'epoch', 'save_steps': 500, 'save_total_limit': None, 'save_safetensors': True, 'save_on_each_node': False, 'save_only_model': False, 'restore_callback_states_from_checkpoint': False, 'no_cuda': False, 'use_cpu': False, 'use_mps_device': False, 'seed': 42, 'data_seed': None, 'jit_mode_eval': False, 'use_ipex': False, 'bf16': False, 'fp16': False, 'fp16_opt_level': 'O1', 'half_precision_backend': 'auto', 'bf16_full_eval': False, 'fp16_full_eval': False, 'tf32': None, 'local_rank': 0, 'ddp_backend': None, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': 
False, 'eval_steps': None, 'dataloader_num_workers': 2, 'dataloader_prefetch_factor': 2, 'past_index': -1, 'run_name': 'smoke-test', 'disable_tqdm': False, 'remove_unused_columns': False, 'label_names': None, 'load_best_model_at_end': True, 'metric_for_best_model': 'eval_loss', 'greater_is_better': False, 'ignore_data_skip': False, 'fsdp': [], 'fsdp_min_num_params': 0, 'fsdp_config': {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, 'tp_size': 0, 'fsdp_transformer_layer_cls_to_wrap': None, 'accelerator_config': {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}, 'deepspeed': None, 'label_smoothing_factor': 0.0, 'optim': 'adamw_torch', 'optim_args': None, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['wandb'], 'ddp_find_unused_parameters': None, 'ddp_bucket_cap_mb': None, 'ddp_broadcast_buffers': None, 'dataloader_pin_memory': True, 'dataloader_persistent_workers': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': False, 'resume_from_checkpoint': None, 'hub_model_id': None, 'hub_strategy': 'every_save', 'hub_token': '<HUB_TOKEN>', 'hub_private_repo': None, 'hub_always_push': False, 'gradient_checkpointing': False, 'gradient_checkpointing_kwargs': None, 'include_inputs_for_metrics': False, 'include_for_metrics': [], 'eval_do_concat_batches': True, 'fp16_backend': 'auto', 'push_to_hub_model_id': None, 'push_to_hub_organization': None, 'push_to_hub_token': '<PUSH_TO_HUB_TOKEN>', 'mp_parameters': '', 'auto_find_batch_size': False, 'full_determinism': False, 'torchdynamo': None, 'ray_scope': 'last', 'ddp_timeout': 1800, 'torch_compile': False, 'torch_compile_backend': None, 'torch_compile_mode': None, 'include_tokens_per_second': False, 'include_num_input_tokens_seen': False, 'neftune_noise_alpha': None, 'optim_target_modules': None, 
'batch_eval_metrics': False, 'eval_on_start': False, 'use_liger_kernel': False, 'eval_use_gather_object': False, 'average_tokens_across_devices': False}
23
+ 2026-05-03 17:12:15,654 INFO MainThread:63423 [wandb_config.py:__setitem__():154] [no run ID] config set model/num_parameters = 79963661 - <bound method Run._config_callback of <wandb.sdk.wandb_run.Run object at 0x14e776850>>
24
+ 2026-05-03 17:12:15,654 INFO MainThread:63423 [wandb_run.py:_config_callback():1363] config_cb model/num_parameters 79963661 None
25
+ 2026-05-03 17:14:51,017 INFO MsgRouterThr:63423 [mailbox.py:close():129] [no run ID] Closing mailbox, abandoning 1 handles.
wandb/run-20260503_171213-h9m78x54/files/code/train.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ from dataclasses import dataclass
5
+ from typing import Dict, List, Optional
6
+
7
+ import numpy as np
8
+ import torch
9
+ from datasets import DatasetDict, load_dataset
10
+ from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
11
+ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
12
+ from transformers.trainer_utils import EvalPrediction
13
+
14
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15
+ from modeling_virtual_cell import VirtualCellPatientConfig, VirtualCellPatientModel
16
+
17
+
18
@dataclass
class PatientCollator:
    """Collate per-sample feature dicts into a single batch of tensors.

    Every feature must provide ``input_ids``, ``attention_mask``, ``labels``
    and ``entity_id``.
    """

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Return the batch dict built from ``features``."""

        def stacked(key: str, dtype: torch.dtype) -> torch.Tensor:
            # One tensor per sample, stacked along a new leading batch axis.
            return torch.stack([torch.tensor(sample[key], dtype=dtype) for sample in features])

        def per_sample(key: str) -> torch.Tensor:
            # One value per sample, gathered into a single long tensor.
            return torch.tensor([sample[key] for sample in features], dtype=torch.long)

        batch: Dict[str, torch.Tensor] = {}
        batch["input_ids"] = stacked("input_ids", torch.float32)
        batch["attention_mask"] = stacked("attention_mask", torch.bool)
        batch["labels"] = per_sample("labels")
        batch["entity_id"] = per_sample("entity_id")
        return batch
31
+
32
+
33
+ def _patient_predictions(logits: np.ndarray, entity_ids: np.ndarray):
34
+ """Average softmax probabilities across augmented views, one row per patient."""
35
+ entity_ids = np.asarray(entity_ids).astype(str)
36
+ unique = np.unique(entity_ids)
37
+ agg = []
38
+ for eid in unique:
39
+ views = logits[entity_ids == eid]
40
+ exp = np.exp(views - np.max(views, axis=1, keepdims=True))
41
+ agg.append(np.mean(exp / exp.sum(axis=1, keepdims=True), axis=0))
42
+ return np.array(agg), unique
43
+
44
+
45
def _clf_metrics(y_true: np.ndarray, y_pred: np.ndarray, prefix: str) -> Dict[str, float]:
    """Return accuracy and macro-averaged F1, precision and recall.

    Every key of the returned dict is ``prefix`` followed by the metric name.
    The macro-averaged scores are computed with ``zero_division=0``.
    """
    macro = {"average": "macro", "zero_division": 0}
    scores = {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, **macro),
        "precision": precision_score(y_true, y_pred, **macro),
        "recall": recall_score(y_true, y_pred, **macro),
    }
    return {f"{prefix}{name}": value for name, value in scores.items()}
52
+
53
+
54
def compute_metrics(eval_pred: EvalPrediction) -> Dict[str, float]:
    """Compute per-view and per-patient classification metrics.

    ``eval_pred.predictions`` is read as an ``(N, num_classes + 1)`` array
    whose last column is the entity id appended by
    ``PatientTrainer.prediction_step``; the remaining columns are the logits.

    Returns:
        A dict with ``per_view/*`` metrics (every row scored on its own) and
        ``patient/*`` metrics (views of one entity averaged first, scored
        against the label of that entity's first row).
    """
    logits_with_entity = eval_pred.predictions  # (N, num_classes + 1)
    logits = logits_with_entity[:, :-1]
    # The id column travels as a float alongside the logits; cast it back.
    entity_ids = logits_with_entity[:, -1].astype(int)
    labels = eval_pred.label_ids

    metrics = _clf_metrics(labels, np.argmax(logits, axis=1), "per_view/")

    patient_preds, unique_entities = _patient_predictions(logits, entity_ids)

    # Index the first row of every entity in a single pass instead of
    # scanning all rows once per patient.
    first_row: Dict[int, int] = {}
    for row, eid in enumerate(entity_ids.tolist()):
        first_row.setdefault(eid, row)
    patient_labels = np.array([labels[first_row[int(eid)]] for eid in unique_entities])
    metrics.update(_clf_metrics(patient_labels, np.argmax(patient_preds, axis=1), "patient/"))

    return metrics
70
+
71
+
72
class PatientTrainer(Trainer):
    """Trainer that uses the model's own loss and carries entity ids through evaluation."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on the batch and return the ``loss`` it reports.

        Returns ``(loss, outputs)`` when ``return_outputs`` is true, otherwise
        only the loss. Extra keyword arguments are accepted and ignored.
        """
        model_outputs = model(**inputs)
        if return_outputs:
            return model_outputs.loss, model_outputs
        return model_outputs.loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Run the parent prediction step and append the entity id to the logits.

        ``entity_id`` is removed from ``inputs`` before delegating. When the
        parent step returns logits, the ids are added as one extra trailing
        column so they reach ``compute_metrics`` together with the logits.
        """
        entity_id = inputs.pop("entity_id")
        step_result = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys=ignore_keys
        )
        loss, logits, labels = step_result
        if logits is None:
            return loss, logits, labels
        # NOTE(review): ids are stored as float32, which is only exact for
        # integers up to 2**24 -- confirm entity ids stay below that.
        entity_col = entity_id.float().unsqueeze(1).to(logits.device)
        return loss, torch.cat([logits, entity_col], dim=1), labels
86
+
87
+
88
def parse_args():
    """Build the training command-line interface and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with one attribute per option below.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order.
    options = [
        ("--dataset_path", {
            "required": True,
            "help": "HF dataset ID or local path with train (and optionally validation) splits",
        }),
        ("--model_name_or_path", {"default": "ConvergeBio/virtual-cell-patient"}),
        ("--hf_token", {"default": None}),
        ("--output_dir", {"default": "./vc_output"}),
        ("--from_scratch", {"action": "store_true"}),
        ("--freeze_embedder", {"action": "store_true"}),
        ("--num_classes", {"type": int, "default": None}),
        ("--num_train_epochs", {"type": int, "default": 15}),
        ("--per_device_train_batch_size", {"type": int, "default": 32}),
        ("--per_device_eval_batch_size", {"type": int, "default": 32}),
        ("--learning_rate", {"type": float, "default": 1e-4}),
        ("--weight_decay", {"type": float, "default": 0.05}),
        ("--warmup_ratio", {"type": float, "default": 0.1}),
        ("--lr_scheduler_type", {"default": "cosine"}),
        ("--patience", {"type": int, "default": 5}),
        ("--num_workers", {"type": int, "default": 4}),
        ("--prefetch_factor", {"type": int, "default": 2}),
        ("--wandb_project", {"default": None}),
        ("--run_name", {"default": None}),
    ]
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
113
+
114
+
115
def main():
    """Train a ``VirtualCellPatientModel`` from the command-line options.

    Loads the dataset, builds the config and model (pretrained weights or
    random initialisation with ``--from_scratch``), optionally freezes the
    patient embedder, trains with ``PatientTrainer`` and saves the final model
    to ``--output_dir``. Evaluation, best-checkpoint loading, early stopping
    and metric computation are enabled only when the dataset has a
    ``validation`` split.
    """
    args = parse_args()

    # ``DatasetDict.save_to_disk`` writes a ``dataset_dict.json`` manifest.
    # Only such directories are read with ``load_from_disk``; hub ids and
    # plain local data folders go through ``load_dataset``.
    saved_manifest = os.path.join(args.dataset_path, "dataset_dict.json")
    if os.path.isfile(saved_manifest):
        ds = DatasetDict.load_from_disk(args.dataset_path)
    else:
        # ``--num_workers 0`` is a valid dataloader setting but not a valid
        # process count, so fall back to single-process loading.
        num_proc = args.num_workers if args.num_workers > 0 else None
        ds = load_dataset(args.dataset_path, num_proc=num_proc)
    train_ds = ds["train"]
    val_ds: Optional[object] = ds.get("validation")

    hf_kwargs = {"trust_remote_code": True}
    if args.hf_token:
        hf_kwargs["token"] = args.hf_token

    config = VirtualCellPatientConfig.from_pretrained(args.model_name_or_path, **hf_kwargs)
    if args.num_classes is not None:
        config.num_classes = args.num_classes
        # id2label is keyed by integer class index, label2id by label name.
        config.id2label = {i: str(i) for i in range(args.num_classes)}
        config.label2id = {str(i): i for i in range(args.num_classes)}

    if args.from_scratch:
        model = VirtualCellPatientModel(config)
    else:
        model = VirtualCellPatientModel.from_pretrained(
            args.model_name_or_path, config=config, **hf_kwargs
        )

    if args.freeze_embedder:
        for param in model.patient_embedder.parameters():
            param.requires_grad = False

    if args.wandb_project:
        os.environ["WANDB_PROJECT"] = args.wandb_project

    has_val = val_ds is not None
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type=args.lr_scheduler_type,
        eval_strategy="epoch" if has_val else "no",
        save_strategy="epoch",
        load_best_model_at_end=has_val,
        metric_for_best_model="eval_loss" if has_val else None,
        greater_is_better=False,
        report_to="wandb" if args.wandb_project else "none",
        run_name=args.run_name,
        dataloader_num_workers=args.num_workers,
        # Prefetching and persistent workers only apply with worker processes.
        dataloader_prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
        dataloader_persistent_workers=args.num_workers > 0,
        dataloader_pin_memory=True,
        # Keep ``entity_id`` in the batches; the collator and
        # ``PatientTrainer.prediction_step`` both read it.
        remove_unused_columns=False,
    )

    callbacks = [EarlyStoppingCallback(args.patience)] if has_val else []

    trainer = PatientTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=PatientCollator(),
        compute_metrics=compute_metrics if has_val else None,
        callbacks=callbacks,
    )

    trainer.train()
    trainer.save_model(args.output_dir)


if __name__ == "__main__":
    main()
wandb/run-20260503_171213-h9m78x54/files/config.yaml ADDED
@@ -0,0 +1,516 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _attn_implementation_autoset:
2
+ value: true
3
+ _name_or_path:
4
+ value: /Users/daniellemillersayag/Documents/vcell/paper/hf-release
5
+ _wandb:
6
+ value:
7
+ cli_version: 0.21.0
8
+ code_path: code/train.py
9
+ e:
10
+ fv6s7853m72kjtsdphyqhm5sm6sgz3ly:
11
+ apple: {}
12
+ args:
13
+ - --dataset_path
14
+ - /Users/daniellemillersayag/Documents/vcell/paper/example_dataset
15
+ - --model_name_or_path
16
+ - /Users/daniellemillersayag/Documents/vcell/paper/hf-release
17
+ - --output_dir
18
+ - /tmp/vc_smoke_test_wandb
19
+ - --num_train_epochs
20
+ - "2"
21
+ - --per_device_train_batch_size
22
+ - "4"
23
+ - --per_device_eval_batch_size
24
+ - "4"
25
+ - --num_workers
26
+ - "2"
27
+ - --patience
28
+ - "5"
29
+ - --wandb_project
30
+ - virtual-cell-patient
31
+ - --run_name
32
+ - smoke-test
33
+ codePath: train.py
34
+ codePathLocal: train.py
35
+ cpu_count: 11
36
+ cpu_count_logical: 11
37
+ disk:
38
+ /:
39
+ total: "994662584320"
40
+ used: "276313182208"
41
+ email: danielle.miller@converge-bio.com
42
+ executable: /Users/daniellemillersayag/Documents/Repos/virtual-cell/venv/bin/python
43
+ host: Mac.lan
44
+ memory:
45
+ total: "38654705664"
46
+ os: macOS-26.3.1-arm64-arm-64bit
47
+ program: /Users/daniellemillersayag/Documents/vcell/paper/hf-release/train.py
48
+ python: CPython 3.11.10
49
+ root: /Users/daniellemillersayag/Documents/vcell/paper/hf-release
50
+ startedAt: "2026-05-03T14:12:13.849133Z"
51
+ writerId: fv6s7853m72kjtsdphyqhm5sm6sgz3ly
52
+ m:
53
+ - "1": train/global_step
54
+ "6":
55
+ - 3
56
+ "7": []
57
+ - "2": '*'
58
+ "5": 1
59
+ "6":
60
+ - 1
61
+ "7": []
62
+ python_version: 3.11.10
63
+ t:
64
+ "1":
65
+ - 1
66
+ - 5
67
+ - 11
68
+ - 12
69
+ - 49
70
+ - 51
71
+ - 53
72
+ - 71
73
+ "2":
74
+ - 1
75
+ - 5
76
+ - 11
77
+ - 12
78
+ - 49
79
+ - 51
80
+ - 53
81
+ - 71
82
+ "3":
83
+ - 7
84
+ - 13
85
+ - 19
86
+ - 62
87
+ - 66
88
+ "4": 3.11.10
89
+ "5": 0.21.0
90
+ "6": 4.51.3
91
+ "9":
92
+ "1": transformers_trainer
93
+ "12": 0.21.0
94
+ "13": darwin-arm64
95
+ accelerator_config:
96
+ value:
97
+ dispatch_batches: null
98
+ even_batches: true
99
+ gradient_accumulation_kwargs: null
100
+ non_blocking: false
101
+ split_batches: false
102
+ use_seedable_sampler: true
103
+ activation:
104
+ value: prelu
105
+ adafactor:
106
+ value: false
107
+ adam_beta1:
108
+ value: 0.9
109
+ adam_beta2:
110
+ value: 0.999
111
+ adam_epsilon:
112
+ value: 1e-08
113
+ add_cross_attention:
114
+ value: false
115
+ architectures:
116
+ value:
117
+ - VirtualCellPatientModel
118
+ attention_hidden_dim:
119
+ value: 512
120
+ auto_find_batch_size:
121
+ value: false
122
+ auto_map:
123
+ value:
124
+ AutoConfig: modeling_virtual_cell.VirtualCellPatientConfig
125
+ AutoModel: modeling_virtual_cell.VirtualCellPatientModel
126
+ average_tokens_across_devices:
127
+ value: false
128
+ bad_words_ids:
129
+ value: null
130
+ batch_eval_metrics:
131
+ value: false
132
+ begin_suppress_tokens:
133
+ value: null
134
+ bf16:
135
+ value: false
136
+ bf16_full_eval:
137
+ value: false
138
+ bos_token_id:
139
+ value: null
140
+ chunk_size_feed_forward:
141
+ value: 0
142
+ classifier_dropout:
143
+ value: 0.1
144
+ cross_attention_hidden_size:
145
+ value: null
146
+ data_seed:
147
+ value: null
148
+ dataloader_drop_last:
149
+ value: false
150
+ dataloader_num_workers:
151
+ value: 2
152
+ dataloader_persistent_workers:
153
+ value: true
154
+ dataloader_pin_memory:
155
+ value: true
156
+ dataloader_prefetch_factor:
157
+ value: 2
158
+ ddp_backend:
159
+ value: null
160
+ ddp_broadcast_buffers:
161
+ value: null
162
+ ddp_bucket_cap_mb:
163
+ value: null
164
+ ddp_find_unused_parameters:
165
+ value: null
166
+ ddp_timeout:
167
+ value: 1800
168
+ debug:
169
+ value: []
170
+ decoder_start_token_id:
171
+ value: null
172
+ deepspeed:
173
+ value: null
174
+ disable_tqdm:
175
+ value: false
176
+ diversity_penalty:
177
+ value: 0
178
+ do_eval:
179
+ value: true
180
+ do_predict:
181
+ value: false
182
+ do_sample:
183
+ value: false
184
+ do_train:
185
+ value: false
186
+ dropout:
187
+ value: 0.1
188
+ early_stopping:
189
+ value: false
190
+ embed_dim:
191
+ value: 512
192
+ encoder_no_repeat_ngram_size:
193
+ value: 0
194
+ eos_token_id:
195
+ value: null
196
+ eval_accumulation_steps:
197
+ value: null
198
+ eval_delay:
199
+ value: 0
200
+ eval_do_concat_batches:
201
+ value: true
202
+ eval_on_start:
203
+ value: false
204
+ eval_steps:
205
+ value: null
206
+ eval_strategy:
207
+ value: epoch
208
+ eval_use_gather_object:
209
+ value: false
210
+ exponential_decay_length_penalty:
211
+ value: null
212
+ finetuning_task:
213
+ value: null
214
+ forced_bos_token_id:
215
+ value: null
216
+ forced_eos_token_id:
217
+ value: null
218
+ fp16:
219
+ value: false
220
+ fp16_backend:
221
+ value: auto
222
+ fp16_full_eval:
223
+ value: false
224
+ fp16_opt_level:
225
+ value: O1
226
+ fsdp:
227
+ value: []
228
+ fsdp_config:
229
+ value:
230
+ min_num_params: 0
231
+ xla: false
232
+ xla_fsdp_grad_ckpt: false
233
+ xla_fsdp_v2: false
234
+ fsdp_min_num_params:
235
+ value: 0
236
+ fsdp_transformer_layer_cls_to_wrap:
237
+ value: null
238
+ full_determinism:
239
+ value: false
240
+ gradient_accumulation_steps:
241
+ value: 1
242
+ gradient_checkpointing:
243
+ value: false
244
+ gradient_checkpointing_kwargs:
245
+ value: null
246
+ greater_is_better:
247
+ value: false
248
+ group_by_length:
249
+ value: false
250
+ half_precision_backend:
251
+ value: auto
252
+ hidden_dim:
253
+ value:
254
+ - 4096
255
+ - 1024
256
+ hub_always_push:
257
+ value: false
258
+ hub_model_id:
259
+ value: null
260
+ hub_private_repo:
261
+ value: null
262
+ hub_strategy:
263
+ value: every_save
264
+ hub_token:
265
+ value: <HUB_TOKEN>
266
+ id2label:
267
+ value:
268
+ "0": oncological
269
+ "1": immune_inflammatory
270
+ "2": neurological
271
+ "3": metabolic_vascular
272
+ "4": gastrointestinal
273
+ "5": respiratory
274
+ "6": epithelial_barrier
275
+ "7": sensory_specialized
276
+ "8": healthy_control
277
+ "9": other
278
+ ignore_data_skip:
279
+ value: false
280
+ include_for_metrics:
281
+ value: []
282
+ include_inputs_for_metrics:
283
+ value: false
284
+ include_num_input_tokens_seen:
285
+ value: false
286
+ include_tokens_per_second:
287
+ value: false
288
+ is_decoder:
289
+ value: false
290
+ is_encoder_decoder:
291
+ value: false
292
+ jit_mode_eval:
293
+ value: false
294
+ label_names:
295
+ value: null
296
+ label_smoothing_factor:
297
+ value: 0
298
+ label2id:
299
+ value:
300
+ epithelial_barrier: 6
301
+ gastrointestinal: 4
302
+ healthy_control: 8
303
+ immune_inflammatory: 1
304
+ metabolic_vascular: 3
305
+ neurological: 2
306
+ oncological: 0
307
+ other: 9
308
+ respiratory: 5
309
+ sensory_specialized: 7
310
+ learning_rate:
311
+ value: 0.0001
312
+ length_column_name:
313
+ value: length
314
+ length_penalty:
315
+ value: 1
316
+ load_best_model_at_end:
317
+ value: true
318
+ local_rank:
319
+ value: 0
320
+ log_level:
321
+ value: passive
322
+ log_level_replica:
323
+ value: warning
324
+ log_on_each_node:
325
+ value: true
326
+ logging_dir:
327
+ value: /tmp/vc_smoke_test_wandb/runs/May03_17-12-12_Mac.lan
328
+ logging_first_step:
329
+ value: false
330
+ logging_nan_inf_filter:
331
+ value: true
332
+ logging_steps:
333
+ value: 500
334
+ logging_strategy:
335
+ value: steps
336
+ lr_scheduler_type:
337
+ value: cosine
338
+ max_grad_norm:
339
+ value: 1
340
+ max_length:
341
+ value: 20
342
+ max_steps:
343
+ value: -1
344
+ metric_for_best_model:
345
+ value: eval_loss
346
+ min_length:
347
+ value: 0
348
+ model/num_parameters:
349
+ value: 79963661
350
+ model_type:
351
+ value: virtual_cell_patient
352
+ mp_parameters:
353
+ value: ""
354
+ n_genes:
355
+ value: 18301
356
+ neftune_noise_alpha:
357
+ value: null
358
+ no_cuda:
359
+ value: false
360
+ no_repeat_ngram_size:
361
+ value: 0
362
+ num_beam_groups:
363
+ value: 1
364
+ num_beams:
365
+ value: 1
366
+ num_classes:
367
+ value: 10
368
+ num_return_sequences:
369
+ value: 1
370
+ num_train_epochs:
371
+ value: 2
372
+ optim:
373
+ value: adamw_torch
374
+ optim_args:
375
+ value: null
376
+ optim_target_modules:
377
+ value: null
378
+ output_attentions:
379
+ value: false
380
+ output_dir:
381
+ value: /tmp/vc_smoke_test_wandb
382
+ output_hidden_states:
383
+ value: false
384
+ output_scores:
385
+ value: false
386
+ overwrite_output_dir:
387
+ value: false
388
+ pad_token_id:
389
+ value: null
390
+ past_index:
391
+ value: -1
392
+ per_device_eval_batch_size:
393
+ value: 4
394
+ per_device_train_batch_size:
395
+ value: 4
396
+ per_gpu_eval_batch_size:
397
+ value: null
398
+ per_gpu_train_batch_size:
399
+ value: null
400
+ prediction_loss_only:
401
+ value: false
402
+ prefix:
403
+ value: null
404
+ problem_type:
405
+ value: null
406
+ push_to_hub:
407
+ value: false
408
+ push_to_hub_model_id:
409
+ value: null
410
+ push_to_hub_organization:
411
+ value: null
412
+ push_to_hub_token:
413
+ value: <PUSH_TO_HUB_TOKEN>
414
+ ray_scope:
415
+ value: last
416
+ remove_invalid_values:
417
+ value: false
418
+ remove_unused_columns:
419
+ value: false
420
+ repetition_penalty:
421
+ value: 1
422
+ report_to:
423
+ value:
424
+ - wandb
425
+ residual:
426
+ value: false
427
+ restore_callback_states_from_checkpoint:
428
+ value: false
429
+ resume_from_checkpoint:
430
+ value: null
431
+ return_dict:
432
+ value: true
433
+ return_dict_in_generate:
434
+ value: false
435
+ run_name:
436
+ value: smoke-test
437
+ save_on_each_node:
438
+ value: false
439
+ save_only_model:
440
+ value: false
441
+ save_safetensors:
442
+ value: true
443
+ save_steps:
444
+ value: 500
445
+ save_strategy:
446
+ value: epoch
447
+ save_total_limit:
448
+ value: null
449
+ seed:
450
+ value: 42
451
+ sep_token_id:
452
+ value: null
453
+ skip_memory_metrics:
454
+ value: true
455
+ suppress_tokens:
456
+ value: null
457
+ task_specific_params:
458
+ value: null
459
+ temperature:
460
+ value: 1
461
+ tf_legacy_loss:
462
+ value: false
463
+ tf32:
464
+ value: null
465
+ tie_encoder_decoder:
466
+ value: false
467
+ tie_word_embeddings:
468
+ value: true
469
+ tokenizer_class:
470
+ value: null
471
+ top_k:
472
+ value: 50
473
+ top_p:
474
+ value: 1
475
+ torch_compile:
476
+ value: false
477
+ torch_compile_backend:
478
+ value: null
479
+ torch_compile_mode:
480
+ value: null
481
+ torch_dtype:
482
+ value: float32
483
+ torch_empty_cache_steps:
484
+ value: null
485
+ torchdynamo:
486
+ value: null
487
+ torchscript:
488
+ value: false
489
+ tp_size:
490
+ value: 0
491
+ tpu_metrics_debug:
492
+ value: false
493
+ tpu_num_cores:
494
+ value: null
495
+ transformers_version:
496
+ value: 4.51.3
497
+ typical_p:
498
+ value: 1
499
+ use_bfloat16:
500
+ value: false
501
+ use_cpu:
502
+ value: false
503
+ use_ipex:
504
+ value: false
505
+ use_legacy_prediction_loop:
506
+ value: false
507
+ use_liger_kernel:
508
+ value: false
509
+ use_mps_device:
510
+ value: false
511
+ warmup_ratio:
512
+ value: 0.1
513
+ warmup_steps:
514
+ value: 0
515
+ weight_decay:
516
+ value: 0.05
wandb/run-20260503_171213-h9m78x54/files/output.log ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ 100%|██████████| 20/20 [02:14<00:00, 6.74s/it]
2
+
3
+ {'eval_loss': 2.8007757663726807, 'eval_per_view/accuracy': 0.3333333333333333, 'eval_per_view/f1_macro': 0.25, 'eval_per_view/precision': 0.25, 'eval_per_view/recall': 0.25, 'eval_patient/accuracy': 0.3333333333333333, 'eval_patient/f1_macro': 0.25, 'eval_patient/precision': 0.25, 'eval_patient/recall': 0.25, 'eval_runtime': 20.6865, 'eval_samples_per_second': 0.725, 'eval_steps_per_second': 0.193, 'epoch': 1.0}
4
+ {'eval_loss': 3.5730626583099365, 'eval_per_view/accuracy': 0.3333333333333333, 'eval_per_view/f1_macro': 0.25, 'eval_per_view/precision': 0.25, 'eval_per_view/recall': 0.25, 'eval_patient/accuracy': 0.3333333333333333, 'eval_patient/f1_macro': 0.25, 'eval_patient/precision': 0.25, 'eval_patient/recall': 0.25, 'eval_runtime': 21.0117, 'eval_samples_per_second': 0.714, 'eval_steps_per_second': 0.19, 'epoch': 2.0}
5
+ {'train_runtime': 137.2132, 'train_samples_per_second': 0.583, 'train_steps_per_second': 0.146, 'train_loss': 0.570319652557373, 'epoch': 2.0}
wandb/run-20260503_171213-h9m78x54/files/requirements.txt ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ propcache==0.3.2
2
+ soupsieve==2.7
3
+ fsspec==2026.4.0
4
+ contourpy==1.3.2
5
+ arrow==1.3.0
6
+ s3fs==2026.4.0
7
+ threadpoolctl==3.6.0
8
+ uri-template==1.3.0
9
+ rfc3339-validator==0.1.4
10
+ pydantic_core==2.33.2
11
+ sparse==0.17.0
12
+ flax==0.10.7
13
+ pexpect==4.9.0
14
+ argcomplete==3.6.3
15
+ optax==0.2.5
16
+ transformers==4.51.3
17
+ fqdn==1.5.1
18
+ llvmlite==0.44.0
19
+ jupyter_core==5.8.1
20
+ idna==3.10
21
+ babel==2.17.0
22
+ traitlets==5.14.3
23
+ aioitertools==0.13.0
24
+ urllib3==2.5.0
25
+ jupyter_server==2.16.0
26
+ nest-asyncio==1.6.0
27
+ pyro-ppl==1.9.1
28
+ debugpy==1.8.14
29
+ kiwisolver==1.4.8
30
+ ptyprocess==0.7.0
31
+ jaxlib==0.6.2
32
+ isoduration==20.11.0
33
+ nox==2026.2.9
34
+ certifi==2025.7.14
35
+ pytz==2025.2
36
+ narwhals==1.47.0
37
+ toml==0.10.2
38
+ legacy-api-wrap==1.4.1
39
+ dependency-groups==1.3.1
40
+ joblib==1.5.1
41
+ protobuf==6.31.1
42
+ typing-inspection==0.4.1
43
+ multidict==6.6.3
44
+ virtualenv==21.2.0
45
+ overrides==7.7.0
46
+ parso==0.8.4
47
+ webencodings==0.5.1
48
+ tinycss2==1.4.0
49
+ jupyterlab==4.4.4
50
+ array-api-compat==1.12.0
51
+ docrep==0.3.2
52
+ hf-xet==1.1.5
53
+ pydeck==0.9.1
54
+ ipython==9.4.0
55
+ beautifulsoup4==4.13.4
56
+ wandb==0.21.0
57
+ h5py==3.14.0
58
+ mudata==0.3.2
59
+ platformdirs==4.3.8
60
+ wrapt==2.1.2
61
+ opt_einsum==3.4.0
62
+ anyio==4.9.0
63
+ defusedxml==0.7.1
64
+ ipywidgets==8.1.7
65
+ pip==24.0
66
+ attrs==25.3.0
67
+ pandas==2.3.1
68
+ natsort==8.4.0
69
+ async-lru==2.0.5
70
+ blinker==1.9.0
71
+ tenacity==9.1.2
72
+ notebook==7.4.4
73
+ markdown-it-py==3.0.0
74
+ seaborn==0.13.2
75
+ gseapy==1.1.12
76
+ GitPython==3.1.44
77
+ pyparsing==3.2.3
78
+ pyzmq==27.0.0
79
+ python-pptx==1.0.2
80
+ jupyter-console==6.6.3
81
+ jedi==0.19.2
82
+ pytest==8.4.1
83
+ charset-normalizer==3.4.2
84
+ numpyro==0.18.0
85
+ requests==2.32.4
86
+ torchmetrics==1.7.4
87
+ typing_extensions==4.14.1
88
+ jupyter==1.1.1
89
+ numba==0.61.2
90
+ patsy==1.0.1
91
+ aiohttp==3.12.14
92
+ treescope==0.1.9
93
+ jupyter_client==8.6.3
94
+ distlib==0.4.0
95
+ pynndescent==0.5.13
96
+ asttokens==3.0.0
97
+ tqdm==4.67.1
98
+ matplotlib==3.10.3
99
+ pandocfilters==1.5.1
100
+ prometheus_client==0.22.1
101
+ json5==0.12.0
102
+ huggingface-hub==0.33.4
103
+ fastjsonschema==2.21.1
104
+ jsonpointer==3.0.0
105
+ tzdata==2025.2
106
+ ipython_pygments_lexers==1.1.1
107
+ appnope==0.1.4
108
+ lightning==2.5.2
109
+ numpy==1.26.4
110
+ jax==0.6.2
111
+ httpcore==1.0.9
112
+ filelock==3.25.2
113
+ decorator==5.2.1
114
+ msgpack==1.1.1
115
+ cffi==1.17.1
116
+ captum==0.8.0
117
+ executing==2.2.0
118
+ nbformat==5.10.4
119
+ jupyterlab_widgets==3.0.15
120
+ yarl==1.20.1
121
+ setuptools==65.5.0
122
+ umap-learn==0.5.9.post2
123
+ aiobotocore==3.6.0
124
+ stack-data==0.6.3
125
+ jmespath==1.1.0
126
+ tensorboard==2.19.0
127
+ multiprocess==0.70.16
128
+ Werkzeug==3.1.3
129
+ jsonschema==4.24.0
130
+ xxhash==3.5.0
131
+ nbconvert==7.16.6
132
+ referencing==0.36.2
133
+ regex==2024.11.6
134
+ absl-py==2.3.1
135
+ sentry-sdk==2.33.1
136
+ Send2Trash==1.8.3
137
+ jupyter-lsp==2.2.5
138
+ python-dotenv==1.2.2
139
+ scvi-tools==1.3.2
140
+ nbclient==0.10.2
141
+ h11==0.16.0
142
+ gitdb==4.0.12
143
+ sniffio==1.3.1
144
+ simplejson==3.20.1
145
+ psutil==7.0.0
146
+ fonttools==4.58.5
147
+ rpds-py==0.26.0
148
+ mdurl==0.1.2
149
+ magika==0.6.3
150
+ networkx==3.5
151
+ python-dateutil==2.9.0.post0
152
+ colorlog==6.10.1
153
+ mpmath==1.3.0
154
+ jupyterlab_pygments==0.3.0
155
+ mistune==3.1.3
156
+ torch==2.5.1
157
+ anndata==0.11.4
158
+ wcwidth==0.2.13
159
+ streamlit==1.51.0
160
+ markdownify==1.2.2
161
+ scikit-learn==1.7.0
162
+ tokenizers==0.21.2
163
+ jupyter-events==0.12.0
164
+ prompt_toolkit==3.0.51
165
+ botocore==1.43.0
166
+ aiosignal==1.4.0
167
+ grpcio==1.73.1
168
+ plotly==6.2.0
169
+ toolz==1.0.0
170
+ click==8.2.1
171
+ lightning-utilities==0.14.3
172
+ packaging==25.0
173
+ jupyterlab_server==2.27.3
174
+ argon2-cffi==25.1.0
175
+ webcolors==24.11.1
176
+ jsonschema-specifications==2025.4.1
177
+ pycparser==2.22
178
+ cycler==0.12.1
179
+ Jinja2==3.1.6
180
+ tornado==6.5.1
181
+ session-info2==0.1.2
182
+ dill==0.3.8
183
+ comm==0.2.2
184
+ multipledispatch==1.0.0
185
+ pure_eval==0.2.3
186
+ pydantic==2.11.7
187
+ flatbuffers==25.12.19
188
+ pluggy==1.6.0
189
+ Pygments==2.19.2
190
+ etils==1.13.0
191
+ rfc3986-validator==0.1.1
192
+ python-discovery==1.2.1
193
+ aiohappyeyeballs==2.6.1
194
+ python-json-logger==3.3.0
195
+ terminado==0.18.1
196
+ xgboost==3.1.1
197
+ types-python-dateutil==2.9.0.20250708
198
+ sympy==1.13.1
199
+ argon2-cffi-bindings==21.2.0
200
+ xlsxwriter==3.2.9
201
+ PyYAML==6.0.2
202
+ httpx==0.28.1
203
+ humanize==4.12.3
204
+ lxml==6.1.0
205
+ rich==14.0.0
206
+ matplotlib-inline==0.1.7
207
+ smmap==5.0.2
208
+ matplotlib-venn==1.1.2
209
+ safetensors==0.5.3
210
+ xarray==2025.7.1
211
+ pillow==11.3.0
212
+ ml_collections==1.1.0
213
+ tensorboard-data-server==0.7.2
214
+ pytorch-lightning==2.5.2
215
+ pyro-api==0.1.2
216
+ scipy==1.15.3
217
+ jupyter_server_terminals==0.5.3
218
+ bleach==6.2.0
219
+ orbax-checkpoint==0.11.19
220
+ ml_dtypes==0.5.1
221
+ altair==5.5.0
222
+ tensorstore==0.1.76
223
+ iniconfig==2.1.0
224
+ ipykernel==6.29.5
225
+ zipp==3.23.0
226
+ annotated-types==0.7.0
227
+ scanpy==1.11.3
228
+ datasets==3.2.0
229
+ widgetsnbextension==4.0.14
230
+ Markdown==3.8.2
231
+ six==1.17.0
232
+ importlib_resources==6.5.2
233
+ chex==0.1.89
234
+ pyarrow==20.0.0
235
+ markitdown==0.1.5
236
+ statsmodels==0.14.5
237
+ cachetools==6.2.1
238
+ notebook_shim==0.2.4
239
+ frozenlist==1.7.0
240
+ onnxruntime==1.25.1
241
+ accelerate==1.1.1
242
+ websocket-client==1.8.0
243
+ MarkupSafe==3.0.2
wandb/run-20260503_171213-h9m78x54/files/wandb-metadata.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "macOS-26.3.1-arm64-arm-64bit",
3
+ "python": "CPython 3.11.10",
4
+ "startedAt": "2026-05-03T14:12:13.849133Z",
5
+ "args": [
6
+ "--dataset_path",
7
+ "/Users/daniellemillersayag/Documents/vcell/paper/example_dataset",
8
+ "--model_name_or_path",
9
+ "/Users/daniellemillersayag/Documents/vcell/paper/hf-release",
10
+ "--output_dir",
11
+ "/tmp/vc_smoke_test_wandb",
12
+ "--num_train_epochs",
13
+ "2",
14
+ "--per_device_train_batch_size",
15
+ "4",
16
+ "--per_device_eval_batch_size",
17
+ "4",
18
+ "--num_workers",
19
+ "2",
20
+ "--patience",
21
+ "5",
22
+ "--wandb_project",
23
+ "virtual-cell-patient",
24
+ "--run_name",
25
+ "smoke-test"
26
+ ],
27
+ "program": "/Users/daniellemillersayag/Documents/vcell/paper/hf-release/train.py",
28
+ "codePath": "train.py",
29
+ "codePathLocal": "train.py",
30
+ "email": "danielle.miller@converge-bio.com",
31
+ "root": "/Users/daniellemillersayag/Documents/vcell/paper/hf-release",
32
+ "host": "Mac.lan",
33
+ "executable": "/Users/daniellemillersayag/Documents/Repos/virtual-cell/venv/bin/python",
34
+ "cpu_count": 11,
35
+ "cpu_count_logical": 11,
36
+ "disk": {
37
+ "/": {
38
+ "total": "994662584320",
39
+ "used": "276313182208"
40
+ }
41
+ },
42
+ "memory": {
43
+ "total": "38654705664"
44
+ },
45
+ "apple": {},
46
+ "writerId": "fv6s7853m72kjtsdphyqhm5sm6sgz3ly"
47
+ }
wandb/run-20260503_171213-h9m78x54/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"train/epoch":2,"eval/per_view/f1_macro":0.25,"train_loss":0.570319652557373,"eval/patient/accuracy":0.3333333333333333,"eval/samples_per_second":0.714,"train_steps_per_second":0.146,"_wandb":{"runtime":155},"train/global_step":20,"_timestamp":1.7778176704983711e+09,"_runtime":155,"train_samples_per_second":0.583,"eval/per_view/accuracy":0.3333333333333333,"eval/patient/precision":0.25,"eval/per_view/precision":0.25,"eval/patient/f1_macro":0.25,"total_flos":3.5121959039064e+17,"eval/loss":3.5730626583099365,"train_runtime":137.2132,"eval/per_view/recall":0.25,"eval/runtime":21.0117,"_step":2,"eval/patient/recall":0.25,"eval/steps_per_second":0.19}
wandb/run-20260503_171213-h9m78x54/logs/debug-core.log ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-05-03T17:12:14.46496+03:00","level":"INFO","msg":"main: starting server","port-filename":"/var/folders/rp/15xk3vwn341d11km1j04wfvm0000gn/T/tmpwtefwqry/port-63423.txt","pid":63423,"log-level":0,"disable-analytics":false,"shutdown-on-parent-exit":false,"enable-dcgm-profiling":false}
2
+ {"time":"2026-05-03T17:12:14.46556+03:00","level":"INFO","msg":"server: will exit if parent process dies","ppid":63423}
3
+ {"time":"2026-05-03T17:12:14.465548+03:00","level":"INFO","msg":"server: accepting connections","addr":{"Name":"/var/folders/rp/15xk3vwn341d11km1j04wfvm0000gn/T/wandb-63423-63522-633304008/socket","Net":"unix"}}
4
+ {"time":"2026-05-03T17:12:14.49552+03:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"1"}
5
+ {"time":"2026-05-03T17:12:14.512233+03:00","level":"INFO","msg":"handleInformInit: received","streamId":"h9m78x54","id":"1"}
6
+ {"time":"2026-05-03T17:12:15.049479+03:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"h9m78x54","id":"1"}
7
+ {"time":"2026-05-03T17:14:51.019456+03:00","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"1"}
8
+ {"time":"2026-05-03T17:14:51.019791+03:00","level":"INFO","msg":"server is shutting down"}
9
+ {"time":"2026-05-03T17:14:51.01977+03:00","level":"INFO","msg":"connection: closing","id":"1"}
10
+ {"time":"2026-05-03T17:14:51.019978+03:00","level":"INFO","msg":"connection: closed successfully","id":"1"}
11
+ {"time":"2026-05-03T17:14:51.020159+03:00","level":"INFO","msg":"server: listener closed","addr":{"Name":"/var/folders/rp/15xk3vwn341d11km1j04wfvm0000gn/T/wandb-63423-63522-633304008/socket","Net":"unix"}}
12
+ {"time":"2026-05-03T17:14:52.003896+03:00","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"1"}
13
+ {"time":"2026-05-03T17:14:52.003946+03:00","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"1"}
14
+ {"time":"2026-05-03T17:14:52.003972+03:00","level":"INFO","msg":"server is closed"}
wandb/run-20260503_171213-h9m78x54/logs/debug-internal.log ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-05-03T17:12:14.512438+03:00","level":"INFO","msg":"stream: starting","core version":"0.21.0"}
2
+ {"time":"2026-05-03T17:12:15.049447+03:00","level":"INFO","msg":"stream: created new stream","id":"h9m78x54"}
3
+ {"time":"2026-05-03T17:12:15.049472+03:00","level":"INFO","msg":"stream: started","id":"h9m78x54"}
4
+ {"time":"2026-05-03T17:12:15.049488+03:00","level":"INFO","msg":"writer: Do: started","stream_id":"h9m78x54"}
5
+ {"time":"2026-05-03T17:12:15.049533+03:00","level":"INFO","msg":"sender: started","stream_id":"h9m78x54"}
6
+ {"time":"2026-05-03T17:12:15.049551+03:00","level":"INFO","msg":"handler: started","stream_id":"h9m78x54"}
7
+ {"time":"2026-05-03T17:12:15.531811+03:00","level":"ERROR","msg":"git repo not found","error":"repository does not exist"}
8
+ {"time":"2026-05-03T17:14:51.01985+03:00","level":"INFO","msg":"stream: closing","id":"h9m78x54"}
9
+ {"time":"2026-05-03T17:14:51.643326+03:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
10
+ {"time":"2026-05-03T17:14:51.997995+03:00","level":"INFO","msg":"sender: closed","stream_id":"h9m78x54"}
11
+ {"time":"2026-05-03T17:14:51.998011+03:00","level":"INFO","msg":"handler: closed","stream_id":"h9m78x54"}
12
+ {"time":"2026-05-03T17:14:51.998039+03:00","level":"INFO","msg":"writer: Close: closed","stream_id":"h9m78x54"}
13
+ {"time":"2026-05-03T17:14:51.998606+03:00","level":"INFO","msg":"stream: closed","id":"h9m78x54"}
wandb/run-20260503_171213-h9m78x54/logs/debug.log ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-05-03 17:12:13,856 INFO MainThread:63423 [wandb_setup.py:_flush():80] Current SDK version is 0.21.0
2
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_setup.py:_flush():80] Configure stats pid to 63423
3
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_setup.py:_flush():80] Loading settings from /Users/daniellemillersayag/.config/wandb/settings
4
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_setup.py:_flush():80] Loading settings from /Users/daniellemillersayag/Documents/vcell/paper/hf-release/wandb/settings
5
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_setup.py:_flush():80] Loading settings from environment variables
6
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_init.py:setup_run_log_directory():703] Logging user logs to /Users/daniellemillersayag/Documents/vcell/paper/hf-release/wandb/run-20260503_171213-h9m78x54/logs/debug.log
7
+ 2026-05-03 17:12:13,857 INFO MainThread:63423 [wandb_init.py:setup_run_log_directory():704] Logging internal logs to /Users/daniellemillersayag/Documents/vcell/paper/hf-release/wandb/run-20260503_171213-h9m78x54/logs/debug-internal.log
8
+ 2026-05-03 17:12:13,858 INFO MainThread:63423 [wandb_init.py:init():830] calling init triggers
9
+ 2026-05-03 17:12:13,858 INFO MainThread:63423 [wandb_init.py:init():835] wandb.init called with sweep_config: {}
10
+ config: {'_wandb': {'code_path': 'code/train.py'}}
11
+ 2026-05-03 17:12:13,858 INFO MainThread:63423 [wandb_init.py:init():871] starting backend
12
+ 2026-05-03 17:12:14,495 INFO MainThread:63423 [wandb_init.py:init():874] sending inform_init request
13
+ 2026-05-03 17:12:14,511 INFO MainThread:63423 [wandb_init.py:init():882] backend started and connected
14
+ 2026-05-03 17:12:14,513 INFO MainThread:63423 [wandb_init.py:init():953] updated telemetry
15
+ 2026-05-03 17:12:14,513 INFO MainThread:63423 [wandb_init.py:init():977] communicating run to backend with 90.0 second timeout
16
+ 2026-05-03 17:12:15,529 INFO MainThread:63423 [wandb_init.py:init():1029] starting run threads in backend
17
+ 2026-05-03 17:12:15,651 INFO MainThread:63423 [wandb_run.py:_console_start():2458] atexit reg
18
+ 2026-05-03 17:12:15,651 INFO MainThread:63423 [wandb_run.py:_redirect():2306] redirect: wrap_raw
19
+ 2026-05-03 17:12:15,651 INFO MainThread:63423 [wandb_run.py:_redirect():2375] Wrapping output streams.
20
+ 2026-05-03 17:12:15,651 INFO MainThread:63423 [wandb_run.py:_redirect():2398] Redirects installed.
21
+ 2026-05-03 17:12:15,652 INFO MainThread:63423 [wandb_init.py:init():1075] run started, returning control to user process
22
+ 2026-05-03 17:12:15,653 INFO MainThread:63423 [wandb_run.py:_config_callback():1363] config_cb None None {'return_dict': True, 'output_hidden_states': False, 'output_attentions': False, 'torchscript': False, 'torch_dtype': 'float32', 'use_bfloat16': False, 'tf_legacy_loss': False, 'pruned_heads': {}, 'tie_word_embeddings': True, 'chunk_size_feed_forward': 0, 'is_encoder_decoder': False, 'is_decoder': False, 'cross_attention_hidden_size': None, 'add_cross_attention': False, 'tie_encoder_decoder': False, 'max_length': 20, 'min_length': 0, 'do_sample': False, 'early_stopping': False, 'num_beams': 1, 'num_beam_groups': 1, 'diversity_penalty': 0.0, 'temperature': 1.0, 'top_k': 50, 'top_p': 1.0, 'typical_p': 1.0, 'repetition_penalty': 1.0, 'length_penalty': 1.0, 'no_repeat_ngram_size': 0, 'encoder_no_repeat_ngram_size': 0, 'bad_words_ids': None, 'num_return_sequences': 1, 'output_scores': False, 'return_dict_in_generate': False, 'forced_bos_token_id': None, 'forced_eos_token_id': None, 'remove_invalid_values': False, 'exponential_decay_length_penalty': None, 'suppress_tokens': None, 'begin_suppress_tokens': None, 'architectures': ['VirtualCellPatientModel'], 'finetuning_task': None, 'id2label': {0: 'oncological', 1: 'immune_inflammatory', 2: 'neurological', 3: 'metabolic_vascular', 4: 'gastrointestinal', 5: 'respiratory', 6: 'epithelial_barrier', 7: 'sensory_specialized', 8: 'healthy_control', 9: 'other'}, 'label2id': {'oncological': 0, 'immune_inflammatory': 1, 'neurological': 2, 'metabolic_vascular': 3, 'gastrointestinal': 4, 'respiratory': 5, 'epithelial_barrier': 6, 'sensory_specialized': 7, 'healthy_control': 8, 'other': 9}, 'tokenizer_class': None, 'prefix': None, 'bos_token_id': None, 'pad_token_id': None, 'eos_token_id': None, 'sep_token_id': None, 'decoder_start_token_id': None, 'task_specific_params': None, 'problem_type': None, '_name_or_path': '/Users/daniellemillersayag/Documents/vcell/paper/hf-release', '_attn_implementation_autoset': True, 
'transformers_version': '4.51.3', 'model_type': 'virtual_cell_patient', 'auto_map': {'AutoConfig': 'modeling_virtual_cell.VirtualCellPatientConfig', 'AutoModel': 'modeling_virtual_cell.VirtualCellPatientModel'}, 'n_genes': 18301, 'embed_dim': 512, 'hidden_dim': [4096, 1024], 'dropout': 0.1, 'residual': False, 'activation': 'prelu', 'attention_hidden_dim': 512, 'num_classes': 10, 'classifier_dropout': 0.1, 'output_dir': '/tmp/vc_smoke_test_wandb', 'overwrite_output_dir': False, 'do_train': False, 'do_eval': True, 'do_predict': False, 'eval_strategy': 'epoch', 'prediction_loss_only': False, 'per_device_train_batch_size': 4, 'per_device_eval_batch_size': 4, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'eval_delay': 0, 'torch_empty_cache_steps': None, 'learning_rate': 0.0001, 'weight_decay': 0.05, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 2, 'max_steps': -1, 'lr_scheduler_type': 'cosine', 'lr_scheduler_kwargs': {}, 'warmup_ratio': 0.1, 'warmup_steps': 0, 'log_level': 'passive', 'log_level_replica': 'warning', 'log_on_each_node': True, 'logging_dir': '/tmp/vc_smoke_test_wandb/runs/May03_17-12-12_Mac.lan', 'logging_strategy': 'steps', 'logging_first_step': False, 'logging_steps': 500, 'logging_nan_inf_filter': True, 'save_strategy': 'epoch', 'save_steps': 500, 'save_total_limit': None, 'save_safetensors': True, 'save_on_each_node': False, 'save_only_model': False, 'restore_callback_states_from_checkpoint': False, 'no_cuda': False, 'use_cpu': False, 'use_mps_device': False, 'seed': 42, 'data_seed': None, 'jit_mode_eval': False, 'use_ipex': False, 'bf16': False, 'fp16': False, 'fp16_opt_level': 'O1', 'half_precision_backend': 'auto', 'bf16_full_eval': False, 'fp16_full_eval': False, 'tf32': None, 'local_rank': 0, 'ddp_backend': None, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': 
False, 'eval_steps': None, 'dataloader_num_workers': 2, 'dataloader_prefetch_factor': 2, 'past_index': -1, 'run_name': 'smoke-test', 'disable_tqdm': False, 'remove_unused_columns': False, 'label_names': None, 'load_best_model_at_end': True, 'metric_for_best_model': 'eval_loss', 'greater_is_better': False, 'ignore_data_skip': False, 'fsdp': [], 'fsdp_min_num_params': 0, 'fsdp_config': {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, 'tp_size': 0, 'fsdp_transformer_layer_cls_to_wrap': None, 'accelerator_config': {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}, 'deepspeed': None, 'label_smoothing_factor': 0.0, 'optim': 'adamw_torch', 'optim_args': None, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['wandb'], 'ddp_find_unused_parameters': None, 'ddp_bucket_cap_mb': None, 'ddp_broadcast_buffers': None, 'dataloader_pin_memory': True, 'dataloader_persistent_workers': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': False, 'resume_from_checkpoint': None, 'hub_model_id': None, 'hub_strategy': 'every_save', 'hub_token': '<HUB_TOKEN>', 'hub_private_repo': None, 'hub_always_push': False, 'gradient_checkpointing': False, 'gradient_checkpointing_kwargs': None, 'include_inputs_for_metrics': False, 'include_for_metrics': [], 'eval_do_concat_batches': True, 'fp16_backend': 'auto', 'push_to_hub_model_id': None, 'push_to_hub_organization': None, 'push_to_hub_token': '<PUSH_TO_HUB_TOKEN>', 'mp_parameters': '', 'auto_find_batch_size': False, 'full_determinism': False, 'torchdynamo': None, 'ray_scope': 'last', 'ddp_timeout': 1800, 'torch_compile': False, 'torch_compile_backend': None, 'torch_compile_mode': None, 'include_tokens_per_second': False, 'include_num_input_tokens_seen': False, 'neftune_noise_alpha': None, 'optim_target_modules': None, 
'batch_eval_metrics': False, 'eval_on_start': False, 'use_liger_kernel': False, 'eval_use_gather_object': False, 'average_tokens_across_devices': False}
23
+ 2026-05-03 17:12:15,654 INFO MainThread:63423 [wandb_config.py:__setitem__():154] [no run ID] config set model/num_parameters = 79963661 - <bound method Run._config_callback of <wandb.sdk.wandb_run.Run object at 0x14e776850>>
24
+ 2026-05-03 17:12:15,654 INFO MainThread:63423 [wandb_run.py:_config_callback():1363] config_cb model/num_parameters 79963661 None
25
+ 2026-05-03 17:14:51,017 INFO MsgRouterThr:63423 [mailbox.py:close():129] [no run ID] Closing mailbox, abandoning 1 handles.
wandb/run-20260503_171213-h9m78x54/run-h9m78x54.wandb ADDED
Binary file (24.6 kB). View file