# RoBERTa Dialogue Sentiment Analysis: EmoWOZ Fine-Tuning
Fine-tunes `roberta-base` (or `roberta-large`) on the EmoWOZ dataset for
**per-utterance emotion classification** in task-oriented dialogue, using a
configurable sliding window of preceding dialogue history as context.
## Emotion Labels
| ID | Label | Notes |
|----|--------------|------------------------------|
| -1 | system | Filtered out; not predicted |
| 0 | neutral | |
| 1 | fearful | |
| 2 | dissatisfied | |
| 3 | apologetic | |
| 4 | abusive | |
| 5 | excited | |
| 6 | satisfied | |
Only **user utterances** (emotion ≠ -1) are classified. System turns are
retained as context but not predicted.
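
As a minimal sketch of the filtering rule above, the snippet below maps label IDs to names and selects only user turns as prediction targets. The per-turn dictionary shape (`text`, `emotion`) and the helper name `prediction_targets` are illustrative assumptions; the actual EmoWOZ JSON schema and the code in `preprocessing.py` may differ.

```python
# Label IDs as listed in the table above.
ID2LABEL = {
    0: "neutral",
    1: "fearful",
    2: "dissatisfied",
    3: "apologetic",
    4: "abusive",
    5: "excited",
    6: "satisfied",
}
SYSTEM_LABEL = -1  # system turns: kept as context, never a prediction target


def prediction_targets(dialogue):
    """Return (turn_index, label_name) for every user turn in a dialogue.

    `dialogue` is a hypothetical list of {"text": str, "emotion": int} dicts;
    the real dataset schema may be different.
    """
    return [
        (i, ID2LABEL[turn["emotion"]])
        for i, turn in enumerate(dialogue)
        if turn["emotion"] != SYSTEM_LABEL
    ]


dialogue = [
    {"text": "I need a cheap hotel.", "emotion": 0},
    {"text": "There are none available.", "emotion": -1},
    {"text": "That is really unhelpful.", "emotion": 2},
]
print(prediction_targets(dialogue))  # [(0, 'neutral'), (2, 'dissatisfied')]
```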
## Project Layout
```
roberta_emowoz/
├── configs/
│   └── default.yaml        # All hyperparameters (edit this)
├── data/
│   ├── dataset.py          # DialogueDataset + collator
│   └── preprocessing.py    # JSON → flat utterance samples
├── models/
│   ├── model.py            # RoBERTa wrapper with classification head
│   └── focal_loss.py       # Class-imbalance-aware loss
├── scripts/
│   ├── train.py            # Main training entry point
│   ├── evaluate.py         # Full eval with per-class metrics
│   └── predict.py          # Interactive / batch inference
├── outputs/                # Checkpoints, logs, predictions (gitignored)
└── requirements.txt
```
## Quick Start
Conda is required. Install it first if it is not already available on your machine.
```bash
# 1. Create and activate the environment, then install dependencies
conda create -n nst_v4 python=3.11
conda activate nst_v4
pip install -r requirements_normal.txt
pip install --index-url https://download.pytorch.org/whl/cu121 -r requirements_torch.txt
# 2. Place your data files in the project root (or update config paths)
# set1_train.json set1_val.json set1_test.json
# 3. Train (uses defaults from DEFAULT_CONFIG in train.py, override any value)
# On an RTX 3070, training takes about 2 hours
python train.py
# Or with custom parameters:
python train.py --epochs 10 --batch_size 32 --history_window 4 --loss focal
# 4. Evaluate on test set
# python evaluate.py --checkpoint outputs/best_model
# 5. Interactive prediction
# python predict.py --checkpoint outputs/best_model --history_window 3
```
## Key Hyperparameter: `history_window`
`history_window` controls how many **preceding turns** (both user and system)
are prepended as context before the current utterance.
```
history_window = 0 → [CLS] <current utterance> [SEP]
history_window = 2 → [CLS] <turn-2> [SEP] <turn-1> [SEP] <current> [SEP]
history_window = 4 → [CLS] <t-4> [SEP] <t-3> [SEP] <t-2> [SEP] <t-1> [SEP] <current> [SEP]
```
Turns are ordered oldest → newest. System turns are prefixed with `"SYS:"`,
user turns with `"USR:"` to give the model speaker role signals.
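
The windowing described above can be sketched in plain Python. This is an illustration under assumptions, not the project's `dataset.py`: `build_input` and the `(speaker, text)` turn format are hypothetical, the current utterance is assumed to carry a `USR:` prefix like any other user turn, and the literal `[CLS]` / `[SEP]` strings mirror the schematic rather than the special tokens a RoBERTa tokenizer would actually insert.

```python
def build_input(turns, index, history_window):
    """Format turn `index` with up to `history_window` preceding turns.

    `turns` is a hypothetical list of (speaker, text) pairs, where speaker
    is "user" or "system". Early turns simply get a shorter history.
    """
    prefix = {"user": "USR:", "system": "SYS:"}
    start = max(0, index - history_window)  # clamp at the dialogue start
    window = turns[start : index + 1]       # oldest -> newest, current last
    parts = [f"{prefix[speaker]} {text}" for speaker, text in window]
    return "[CLS] " + " [SEP] ".join(parts) + " [SEP]"


turns = [
    ("user", "I need a taxi."),
    ("system", "Where to?"),
    ("user", "Forget it, this is useless."),
]
print(build_input(turns, index=2, history_window=0))
print(build_input(turns, index=2, history_window=2))
```

With `history_window=0` only the current utterance is kept; with a window larger than the available history, the slice stops at the first turn of the dialogue.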
Recommended sweep: `[0, 2, 4, 6]`.
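
One way to run this sweep is to generate a `train.py` command per value. The sketch below only uses the `--history_window` and `--loss` flags shown in Quick Start; the helper `sweep_commands` is hypothetical, and whether successive runs overwrite each other in `outputs/` depends on `train.py`, which is not shown here.

```python
import shlex

SWEEP = [0, 2, 4, 6]  # the recommended values from this section


def sweep_commands(windows, extra_args=()):
    """Build one `train.py` argument list per history_window value."""
    return [
        ["python", "train.py", "--history_window", str(w), *extra_args]
        for w in windows
    ]


for cmd in sweep_commands(SWEEP, extra_args=["--loss", "focal"]):
    # Print only; pass `cmd` to subprocess.run(cmd, check=True) to execute.
    print(shlex.join(cmd))
```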
## Class Imbalance
EmoWOZ is heavily skewed toward class 0 (neutral). Two mitigation strategies
are included and can be toggled in `configs/default.yaml`:
- **Focal Loss** (`loss: focal`): down-weights easy neutral examples.
- **Weighted Cross-Entropy** (`loss: weighted_ce`): per-class inverse
frequency weights computed from the training set.
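
To make the two strategies concrete, here is a small standard-library sketch of both ideas. It is not the code in `models/focal_loss.py`: the weight normalisation `total / (num_classes * count)` and `gamma=2.0` are common choices assumed for illustration, and both function names are hypothetical.

```python
import math
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Weight each class by total / (num_classes * count).

    Classes absent from `labels` get weight 0.0 to avoid division by zero.
    """
    counts = Counter(labels)
    total = len(labels)
    return [
        total / (num_classes * counts[c]) if counts[c] else 0.0
        for c in range(num_classes)
    ]


def focal_loss(probs, target, gamma=2.0):
    """Focal loss for one example: -(1 - p_t)**gamma * log(p_t).

    `probs` is a softmax output. With gamma=0 this is plain cross-entropy.
    """
    p_t = probs[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


# Toy label distribution: 8 neutral, 1 dissatisfied, 1 satisfied.
labels = [0] * 8 + [2] + [6]
weights = inverse_frequency_weights(labels, num_classes=7)
print(weights)  # rare classes receive larger weights than neutral

# A confidently correct example contributes far less than a hard one.
easy = focal_loss([0.9, 0.1], target=0)
hard = focal_loss([0.2, 0.8], target=0)
ce_easy = -math.log(0.9)  # plain cross-entropy for the easy example
print(easy, ce_easy, hard)
```

The `(1 - p_t)**gamma` factor shrinks the loss on examples the model already classifies confidently, which in a neutral-dominated dataset are mostly neutral turns.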
## Outputs
After training, `outputs/` contains:
- `best_model/`: best checkpoint by macro-F1 on validation
- `last_model/`: final epoch checkpoint
- `training_log.jsonl`: epoch-level metrics
- `test_results.json`: per-class precision / recall / F1 + confusion matrix
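
Because the best checkpoint is selected by macro-F1, it may help to see how that metric behaves on skewed data. The sketch below is a plain-Python illustration, not the project's evaluation code; `macro_f1` is a hypothetical helper, and scoring a class with no true positives as 0 is an assumed convention.

```python
def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores.

    Per-class F1 is 2*TP / (2*TP + FP + FN); a class with no true
    positives, false positives, or false negatives scores 0.0 here.
    """
    pairs = list(zip(y_true, y_pred))
    scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes


# A model that always predicts class 0 on a skewed toy set.
y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 0, 0, 0]
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy, macro_f1(y_true, y_pred, num_classes=3))
```

In this toy case accuracy looks reasonable while macro-F1 stays low, since every class counts equally regardless of how often it occurs.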