Spaces:
Sleeping
Sleeping
OliverPerrin committed on
Commit ·
2261920
1
Parent(s): 6553b4f
Fixed Ruff check and small readme update
Browse files
- README.md +1 -1
- scripts/train_bert_baseline.py +2 -5
README.md
CHANGED
|
@@ -30,7 +30,7 @@ Trained for 8 epochs on an RTX 4070 12GB (~9 hours) with BFloat16 mixed precisio
|
|
| 30 |
|
| 31 |
## Key Findings
|
| 32 |
|
| 33 |
-
From
|
| 34 |
|
| 35 |
- **Naive MTL produces mixed results**: topic classification benefits (+3.7% accuracy), but emotion detection suffers negative transfer (−0.02 F1) under mean pooling with round-robin scheduling.
|
| 36 |
- **Learned attention pooling + temperature sampling eliminates negative transfer entirely**: emotion F1 improves from 0.199 → 0.352 (+77%), surpassing the single-task baseline (0.218).
|
|
|
|
| 30 |
|
| 31 |
## Key Findings
|
| 32 |
|
| 33 |
+
From my research paper:
|
| 34 |
|
| 35 |
- **Naive MTL produces mixed results**: topic classification benefits (+3.7% accuracy), but emotion detection suffers negative transfer (−0.02 F1) under mean pooling with round-robin scheduling.
|
| 36 |
- **Learned attention pooling + temperature sampling eliminates negative transfer entirely**: emotion F1 improves from 0.199 → 0.352 (+77%), surpassing the single-task baseline (0.218).
|
scripts/train_bert_baseline.py
CHANGED
|
@@ -57,7 +57,6 @@ from src.data.dataset import (
|
|
| 57 |
load_emotion_jsonl,
|
| 58 |
load_topic_jsonl,
|
| 59 |
)
|
| 60 |
-
|
| 61 |
from src.training.metrics import (
|
| 62 |
bootstrap_confidence_interval,
|
| 63 |
multilabel_f1,
|
|
@@ -67,7 +66,6 @@ from src.training.metrics import (
|
|
| 67 |
tune_per_class_thresholds,
|
| 68 |
)
|
| 69 |
|
| 70 |
-
|
| 71 |
# Configuration
|
| 72 |
|
| 73 |
@dataclass
|
|
@@ -439,7 +437,6 @@ class BertTrainer:
|
|
| 439 |
self.optimizer.zero_grad()
|
| 440 |
|
| 441 |
epoch_losses: Dict[str, List[float]] = {t: [] for t in self.train_loaders}
|
| 442 |
-
epoch_metrics: Dict[str, List[float]] = {}
|
| 443 |
|
| 444 |
if len(self.train_loaders) > 1:
|
| 445 |
# Multi-task: temperature sampling
|
|
@@ -951,12 +948,12 @@ def run_experiment(mode: str, config: BertBaselineConfig) -> Dict[str, Any]:
|
|
| 951 |
# Load best checkpoint for final evaluation
|
| 952 |
best_path = config.checkpoint_dir / mode / "best.pt"
|
| 953 |
if best_path.exists():
|
| 954 |
-
print(
|
| 955 |
checkpoint = torch.load(best_path, map_location=device, weights_only=False)
|
| 956 |
model.load_state_dict(checkpoint["model_state_dict"])
|
| 957 |
|
| 958 |
# Full evaluation
|
| 959 |
-
print(
|
| 960 |
eval_results = evaluate_bert_model(
|
| 961 |
model,
|
| 962 |
val_loaders,
|
|
|
|
| 57 |
load_emotion_jsonl,
|
| 58 |
load_topic_jsonl,
|
| 59 |
)
|
|
|
|
| 60 |
from src.training.metrics import (
|
| 61 |
bootstrap_confidence_interval,
|
| 62 |
multilabel_f1,
|
|
|
|
| 66 |
tune_per_class_thresholds,
|
| 67 |
)
|
| 68 |
|
|
|
|
| 69 |
# Configuration
|
| 70 |
|
| 71 |
@dataclass
|
|
|
|
| 437 |
self.optimizer.zero_grad()
|
| 438 |
|
| 439 |
epoch_losses: Dict[str, List[float]] = {t: [] for t in self.train_loaders}
|
|
|
|
| 440 |
|
| 441 |
if len(self.train_loaders) > 1:
|
| 442 |
# Multi-task: temperature sampling
|
|
|
|
| 948 |
# Load best checkpoint for final evaluation
|
| 949 |
best_path = config.checkpoint_dir / mode / "best.pt"
|
| 950 |
if best_path.exists():
|
| 951 |
+
print("\n Loading best checkpoint for final evaluation...")
|
| 952 |
checkpoint = torch.load(best_path, map_location=device, weights_only=False)
|
| 953 |
model.load_state_dict(checkpoint["model_state_dict"])
|
| 954 |
|
| 955 |
# Full evaluation
|
| 956 |
+
print("\n Running final evaluation...")
|
| 957 |
eval_results = evaluate_bert_model(
|
| 958 |
model,
|
| 959 |
val_loaders,
|