| from __future__ import annotations |
|
|
| import json |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from datasets import Dataset |
| from sklearn.metrics import accuracy_score, f1_score |
| from transformers import AutoTokenizer, Trainer, TrainingArguments |
|
|
# Make the repository root importable so the sibling modules below
# (``config``, ``multitask_model``, ``training``) resolve when this file is
# executed directly as a script rather than as part of an installed package.
BASE_DIR = Path(__file__).resolve().parent.parent
if str(BASE_DIR) not in sys.path:
    sys.path.insert(0, str(BASE_DIR))
|
|
| from config import ( |
| DECISION_PHASE_DIFFICULTY_DATA_DIR, |
| DECISION_PHASE_HEAD_CONFIG, |
| FULL_INTENT_TAXONOMY_DATA_DIR, |
| INTENT_HEAD_CONFIG, |
| INTENT_TYPE_DIFFICULTY_DATA_DIR, |
| MULTITASK_INTENT_MODEL_DIR, |
| SUBTYPE_DIFFICULTY_DATA_DIR, |
| SUBTYPE_HEAD_CONFIG, |
| ) |
| from multitask_model import MultiTaskIntentModel, MultiTaskLabelSizes |
| from training.common import write_json |
|
|
|
|
| IGNORE_INDEX = -100 |
|
|
|
|
@dataclass
class MultiTaskRow:
    """One merged training example keyed by its text.

    Each task field holds a label id, or ``IGNORE_INDEX`` when no source file
    provided a label for that task, so the corresponding loss term is masked.
    """

    text: str
    intent_type: int = IGNORE_INDEX
    intent_subtype: int = IGNORE_INDEX
    decision_phase: int = IGNORE_INDEX
|
|
|
|
| def _load_task_rows(path: Path, label_field: str, label2id: dict[str, int]) -> list[tuple[str, int]]: |
| if not path.exists(): |
| return [] |
| rows: list[tuple[str, int]] = [] |
| with path.open("r", encoding="utf-8") as handle: |
| for line in handle: |
| item = json.loads(line) |
| rows.append((item["text"], label2id[item[label_field]])) |
| return rows |
|
|
|
|
def _merge_rows(
    split: str,
    include_full_intent: bool = True,
    include_difficulty: bool = True,
) -> list[dict]:
    """Collect per-task rows for *split* and merge them by text.

    Each unique text gets exactly one output row; tasks that never supplied a
    label for that text keep ``IGNORE_INDEX`` so their loss is masked out.

    Args:
        split: Dataset split name ("train", "val", or "test").
        include_full_intent: Also merge the fully-annotated taxonomy file,
            which labels all three heads from a single JSONL.
        include_difficulty: Also merge the per-task difficulty augmentation
            files.

    Returns:
        A list of dicts with keys ``text``, ``intent_type``, ``intent_subtype``
        and ``decision_phase`` (int label ids or ``IGNORE_INDEX``).
    """
    merged: dict[str, MultiTaskRow] = {}

    def upsert(task_key: str, text: str, label: int) -> None:
        # Create-or-update keyed by text so multiple source files can
        # contribute labels for different tasks to the same example.
        row = merged.get(text)
        if row is None:
            row = MultiTaskRow(text=text)
            merged[text] = row
        setattr(row, task_key, int(label))

    # (task attribute name, head config) triples shared by every data source.
    heads = (
        ("intent_type", INTENT_HEAD_CONFIG),
        ("intent_subtype", SUBTYPE_HEAD_CONFIG),
        ("decision_phase", DECISION_PHASE_HEAD_CONFIG),
    )

    # Primary single-task splits: each head has its own file and label field.
    for task_key, config in heads:
        for text, label in _load_task_rows(
            config.split_paths[split], config.label_field, config.label2id
        ):
            upsert(task_key, text, label)

    # Full taxonomy file: one JSONL carries a column per task, named after
    # the task attribute itself.
    if include_full_intent:
        full_path = FULL_INTENT_TAXONOMY_DATA_DIR / f"{split}.jsonl"
        for task_key, config in heads:
            for text, label in _load_task_rows(full_path, task_key, config.label2id):
                upsert(task_key, text, label)

    # Difficulty augmentation: one directory per task, same column naming as
    # the full taxonomy file.
    if include_difficulty:
        difficulty_dirs = (
            INTENT_TYPE_DIFFICULTY_DATA_DIR,
            SUBTYPE_DIFFICULTY_DATA_DIR,
            DECISION_PHASE_DIFFICULTY_DATA_DIR,
        )
        for (task_key, config), data_dir in zip(heads, difficulty_dirs):
            for text, label in _load_task_rows(
                data_dir / f"{split}.jsonl", task_key, config.label2id
            ):
                upsert(task_key, text, label)

    return [
        {
            "text": row.text,
            "intent_type": row.intent_type,
            "intent_subtype": row.intent_subtype,
            "decision_phase": row.decision_phase,
        }
        for row in merged.values()
    ]
|
|
|
|
def _prepare_dataset(rows: list[dict], tokenizer, max_length: int) -> Dataset:
    """Tokenize merged rows into a torch-formatted HF ``Dataset``."""

    def tokenize(batch):
        # Fixed-length padding keeps every example the same shape, so the
        # default collator can stack tensors without a custom padding step.
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )

    prepared = (
        Dataset.from_list(rows)
        .map(tokenize, batched=True)
        .remove_columns(["text"])
    )
    prepared.set_format("torch")
    return prepared
|
|
|
|
class MultiTaskTrainer(Trainer):
    """``Trainer`` variant summing weighted cross-entropy over three heads."""

    #: the three label columns / output-logit prefixes handled by this trainer
    _TASKS = ("intent_type", "intent_subtype", "decision_phase")

    def __init__(self, *args, loss_weights: dict[str, float], **kwargs):
        super().__init__(*args, **kwargs)
        # Scalar weight per task, keyed by task name.
        self.loss_weights = loss_weights

    def _task_ce(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Mean CE over non-ignored labels only.

        ``CrossEntropyLoss(..., reduction='mean')`` returns NaN when every label
        in the batch is ``IGNORE_INDEX`` (0 valid targets). Per-row
        ``reduction='none'`` yields 0 for ignored rows; we then mean over valid
        rows only, matching standard CE otherwise.
        """
        mask = labels != IGNORE_INDEX
        if not mask.any():
            # Zero loss that stays attached to the graph so backward() works.
            return logits.sum() * 0.0
        per_row = F.cross_entropy(
            logits, labels, ignore_index=IGNORE_INDEX, reduction="none"
        )
        return per_row[mask].mean()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Pop the three label columns, run the model, combine weighted losses."""
        targets = {task: inputs.pop(task) for task in self._TASKS}
        outputs = model(**inputs)
        total = sum(
            self.loss_weights[task]
            * self._task_ce(outputs[f"{task}_logits"], targets[task])
            for task in self._TASKS
        )
        return (total, outputs) if return_outputs else total
|
|
|
|
def _masked_metrics(logits: np.ndarray, labels: np.ndarray) -> dict[str, float]:
    """Accuracy/macro-F1 computed only over rows whose label is not ignored."""
    keep = labels != IGNORE_INDEX
    n_valid = int(keep.sum())
    if n_valid == 0:
        # This task has no labelled rows in the eval split.
        return {"accuracy": 0.0, "macro_f1": 0.0, "count": 0}
    true = labels[keep]
    preds = logits[keep].argmax(axis=-1)
    return {
        "accuracy": float(accuracy_score(true, preds)),
        "macro_f1": float(f1_score(true, preds, average="macro")),
        "count": n_valid,
    }
|
|
|
|
def _compute_metrics(eval_pred):
    """Flatten per-task masked metrics into the flat dict Trainer expects.

    ``eval_pred`` carries one logits array and one labels array per task, in
    (intent_type, intent_subtype, decision_phase) order.
    """
    predictions, labels = eval_pred
    tasks = ("intent_type", "intent_subtype", "decision_phase")
    report: dict[str, float] = {}
    for name, task_logits, task_labels in zip(tasks, predictions, labels):
        stats = _masked_metrics(task_logits, task_labels)
        report[f"{name}_accuracy"] = stats["accuracy"]
        report[f"{name}_macro_f1"] = stats["macro_f1"]
    return report
|
|
|
|
def main() -> None:
    """Train the multitask intent model and persist weights, tokenizer, and metrics."""
    # Train/val pull in the auxiliary sources (full taxonomy + difficulty
    # augmentation); test sticks to the primary splits for a clean benchmark.
    train_rows = _merge_rows("train", include_full_intent=True, include_difficulty=True)
    val_rows = _merge_rows("val", include_full_intent=True, include_difficulty=True)
    test_rows = _merge_rows("test", include_full_intent=False, include_difficulty=False)

    tokenizer = AutoTokenizer.from_pretrained(INTENT_HEAD_CONFIG.model_name)
    # One shared encoder serves all heads, so pad to the largest max_length
    # any single head was configured with.
    max_length = max(
        INTENT_HEAD_CONFIG.max_length,
        SUBTYPE_HEAD_CONFIG.max_length,
        DECISION_PHASE_HEAD_CONFIG.max_length,
    )
    train_dataset = _prepare_dataset(train_rows, tokenizer, max_length=max_length)
    val_dataset = _prepare_dataset(val_rows, tokenizer, max_length=max_length)
    test_dataset = _prepare_dataset(test_rows, tokenizer, max_length=max_length)

    model = MultiTaskIntentModel(
        INTENT_HEAD_CONFIG.model_name,
        MultiTaskLabelSizes(
            intent_type=len(INTENT_HEAD_CONFIG.labels),
            intent_subtype=len(SUBTYPE_HEAD_CONFIG.labels),
            decision_phase=len(DECISION_PHASE_HEAD_CONFIG.labels),
        ),
    )

    training_args = TrainingArguments(
        output_dir=str(MULTITASK_INTENT_MODEL_DIR),
        eval_strategy="epoch",
        save_strategy="no",  # weights are saved manually below via torch.save
        logging_strategy="epoch",
        num_train_epochs=4,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        weight_decay=0.01,
        report_to="none",
        # Tell Trainer which dataset columns are labels so they reach
        # compute_loss/compute_metrics instead of being fed to the model.
        label_names=["intent_type", "intent_subtype", "decision_phase"],
    )
    # Equal weighting across tasks; adjust here to rebalance the joint loss.
    loss_weights = {"intent_type": 1.0, "intent_subtype": 1.0, "decision_phase": 1.0}
    trainer = MultiTaskTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=_compute_metrics,
        loss_weights=loss_weights,
    )

    print(f"Loaded multitask splits: train={len(train_rows)} val={len(val_rows)} test={len(test_rows)}")
    trainer.train()
    val_metrics = trainer.evaluate(eval_dataset=val_dataset, metric_key_prefix="val")
    test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
    print(val_metrics)
    print(test_metrics)

    # Persist artifacts: tokenizer files, a raw state_dict checkpoint, and a
    # metadata file describing the custom checkpoint format and label maps
    # (everything needed to reload the model for inference).
    MULTITASK_INTENT_MODEL_DIR.mkdir(parents=True, exist_ok=True)
    tokenizer.save_pretrained(MULTITASK_INTENT_MODEL_DIR)
    torch.save({"state_dict": model.state_dict()}, MULTITASK_INTENT_MODEL_DIR / "multitask_model.pt")
    metadata = {
        "format": "admesh_multitask_intent_v1",
        "base_model_name": INTENT_HEAD_CONFIG.model_name,
        "max_length": max_length,
        "label_maps": {
            "intent_type": {"label2id": INTENT_HEAD_CONFIG.label2id, "id2label": INTENT_HEAD_CONFIG.id2label},
            "intent_subtype": {"label2id": SUBTYPE_HEAD_CONFIG.label2id, "id2label": SUBTYPE_HEAD_CONFIG.id2label},
            "decision_phase": {"label2id": DECISION_PHASE_HEAD_CONFIG.label2id, "id2label": DECISION_PHASE_HEAD_CONFIG.id2label},
        },
    }
    (MULTITASK_INTENT_MODEL_DIR / "metadata.json").write_text(
        json.dumps(metadata, indent=2, sort_keys=True) + "\n",
        encoding="utf-8",
    )
    write_json(
        MULTITASK_INTENT_MODEL_DIR / "train_metrics.json",
        {
            "head": "multitask_intent",
            "loss_weights": loss_weights,
            "train_count": len(train_rows),
            "val_count": len(val_rows),
            "test_count": len(test_rows),
            "val_metrics": val_metrics,
            "test_metrics": test_metrics,
        },
    )
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|