OliverPerrin committed on
Commit
2261920
·
1 Parent(s): 6553b4f

Fixed Ruff check and small readme update

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. scripts/train_bert_baseline.py +2 -5
README.md CHANGED
@@ -30,7 +30,7 @@ Trained for 8 epochs on an RTX 4070 12GB (~9 hours) with BFloat16 mixed precisio
30
 
31
  ## Key Findings
32
 
33
- From the [research paper](docs/research_paper.tex):
34
 
35
  - **Naive MTL produces mixed results**: topic classification benefits (+3.7% accuracy), but emotion detection suffers negative transfer (−0.02 F1) under mean pooling with round-robin scheduling.
36
  - **Learned attention pooling + temperature sampling eliminates negative transfer entirely**: emotion F1 improves from 0.199 → 0.352 (+77%), surpassing the single-task baseline (0.218).
 
30
 
31
  ## Key Findings
32
 
33
+ From my research paper:
34
 
35
  - **Naive MTL produces mixed results**: topic classification benefits (+3.7% accuracy), but emotion detection suffers negative transfer (−0.02 F1) under mean pooling with round-robin scheduling.
36
  - **Learned attention pooling + temperature sampling eliminates negative transfer entirely**: emotion F1 improves from 0.199 → 0.352 (+77%), surpassing the single-task baseline (0.218).
scripts/train_bert_baseline.py CHANGED
@@ -57,7 +57,6 @@ from src.data.dataset import (
57
  load_emotion_jsonl,
58
  load_topic_jsonl,
59
  )
60
-
61
  from src.training.metrics import (
62
  bootstrap_confidence_interval,
63
  multilabel_f1,
@@ -67,7 +66,6 @@ from src.training.metrics import (
67
  tune_per_class_thresholds,
68
  )
69
 
70
-
71
  # Configuration
72
 
73
  @dataclass
@@ -439,7 +437,6 @@ class BertTrainer:
439
  self.optimizer.zero_grad()
440
 
441
  epoch_losses: Dict[str, List[float]] = {t: [] for t in self.train_loaders}
442
- epoch_metrics: Dict[str, List[float]] = {}
443
 
444
  if len(self.train_loaders) > 1:
445
  # Multi-task: temperature sampling
@@ -951,12 +948,12 @@ def run_experiment(mode: str, config: BertBaselineConfig) -> Dict[str, Any]:
951
  # Load best checkpoint for final evaluation
952
  best_path = config.checkpoint_dir / mode / "best.pt"
953
  if best_path.exists():
954
- print(f"\n Loading best checkpoint for final evaluation...")
955
  checkpoint = torch.load(best_path, map_location=device, weights_only=False)
956
  model.load_state_dict(checkpoint["model_state_dict"])
957
 
958
  # Full evaluation
959
- print(f"\n Running final evaluation...")
960
  eval_results = evaluate_bert_model(
961
  model,
962
  val_loaders,
 
57
  load_emotion_jsonl,
58
  load_topic_jsonl,
59
  )
 
60
  from src.training.metrics import (
61
  bootstrap_confidence_interval,
62
  multilabel_f1,
 
66
  tune_per_class_thresholds,
67
  )
68
 
 
69
  # Configuration
70
 
71
  @dataclass
 
437
  self.optimizer.zero_grad()
438
 
439
  epoch_losses: Dict[str, List[float]] = {t: [] for t in self.train_loaders}
 
440
 
441
  if len(self.train_loaders) > 1:
442
  # Multi-task: temperature sampling
 
948
  # Load best checkpoint for final evaluation
949
  best_path = config.checkpoint_dir / mode / "best.pt"
950
  if best_path.exists():
951
+ print("\n Loading best checkpoint for final evaluation...")
952
  checkpoint = torch.load(best_path, map_location=device, weights_only=False)
953
  model.load_state_dict(checkpoint["model_state_dict"])
954
 
955
  # Full evaluation
956
+ print("\n Running final evaluation...")
957
  eval_results = evaluate_bert_model(
958
  model,
959
  val_loaders,