gr8monk3ys committed
Commit 2be4558 · verified · 1 Parent(s): ca0268b

Upload folder using huggingface_hub

Files changed (4)
  1. README.md +146 -0
  2. inference.py +225 -0
  3. requirements.txt +7 -0
  4. train.py +453 -0
README.md ADDED
@@ -0,0 +1,146 @@
+ ---
+ license: mit
+ base_model: distilbert-base-uncased
+ tags:
+ - text-classification
+ - arxiv
+ - academic-papers
+ - distilbert
+ datasets:
+ - ccdv/arxiv-classification
+ metrics:
+ - accuracy
+ - f1
+ pipeline_tag: text-classification
+ ---
+
+ # Academic Paper Classifier
+
+ A DistilBERT model fine-tuned to classify academic paper abstracts into arXiv
+ subject categories. Given the abstract of a research paper, the model predicts
+ which area of computer science or statistics the paper belongs to.
+
+ ## Intended Use
+
+ This model is designed for:
+
+ - **Automated paper triage** -- quickly routing new submissions to the
+   appropriate reviewers or reading lists.
+ - **Literature search** -- filtering large collections of papers by
+   predicted subject area.
+ - **Research tooling** -- as a building block in larger academic-paper
+   analysis pipelines.
+
+ The model is **not** intended for high-stakes decisions such as publication
+ acceptance or funding allocation.
+
+ ## Labels
+
+ | Id | Label   | Description                       |
+ |----|---------|-----------------------------------|
+ | 0  | cs.AI   | Artificial Intelligence           |
+ | 1  | cs.CL   | Computation and Language (NLP)    |
+ | 2  | cs.CV   | Computer Vision                   |
+ | 3  | cs.LG   | Machine Learning                  |
+ | 4  | cs.NE   | Neural and Evolutionary Computing |
+ | 5  | cs.RO   | Robotics                          |
+ | 6  | math.ST | Statistics Theory                 |
+ | 7  | stat.ML | Machine Learning (Statistics)     |
+
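The table above is exactly what the training script stores in the model's `config.json` as `id2label`, so consumers can resolve predictions without hard-coding labels. A minimal sketch of that resolution (the dict literal mirrors the table; at runtime you would read `model.config.id2label` instead):

```python
# Mirror of the label table; at runtime this comes from model.config.id2label.
ID2LABEL = {
    0: "cs.AI", 1: "cs.CL", 2: "cs.CV", 3: "cs.LG",
    4: "cs.NE", 5: "cs.RO", 6: "math.ST", 7: "stat.ML",
}

def resolve(pred_id: int) -> str:
    """Map a raw argmax class id to its arXiv category string."""
    return ID2LABEL[pred_id]

print(resolve(1))  # cs.CL
```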
+ ## Training Procedure
+
+ ### Base Model
+
+ [`distilbert-base-uncased`](https://huggingface.co/distilbert-base-uncased) --
+ a distilled version of BERT that is 60% faster while retaining 97% of BERT's
+ language-understanding performance.
+
+ ### Dataset
+
+ [`ccdv/arxiv-classification`](https://huggingface.co/datasets/ccdv/arxiv-classification)
+ -- a curated collection of arXiv paper abstracts with subject category labels.
+
+ ### Hyperparameters
+
+ | Parameter               | Value              |
+ |-------------------------|--------------------|
+ | Learning rate           | 2e-5               |
+ | LR scheduler            | Linear with warmup |
+ | Warmup ratio            | 0.1                |
+ | Weight decay            | 0.01               |
+ | Epochs                  | 5                  |
+ | Batch size (train)      | 16                 |
+ | Batch size (eval)       | 32                 |
+ | Max sequence length     | 512                |
+ | Early stopping patience | 3                  |
+ | Seed                    | 42                 |
+
+ ### Metrics
+
+ The model is evaluated on accuracy, weighted F1, weighted precision, and
+ weighted recall. The best checkpoint is selected by weighted F1.
+
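For clarity, "weighted" averaging scores each class separately and then averages with each class weighted by its support (its true-label count), so frequent classes dominate the aggregate. A pure-Python sketch of that definition (matching scikit-learn's `average="weighted"`), using toy labels:

```python
from collections import Counter

def weighted_f1(y_true, y_pred):
    """F1 per class, averaged with each class weighted by its true-label count."""
    support = Counter(y_true)
    n = len(y_true)
    total = 0.0
    for c in sorted(set(y_true)):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        pred_c = sum(1 for p in y_pred if p == c)
        precision = tp / pred_c if pred_c else 0.0
        recall = tp / support[c]
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        total += (support[c] / n) * f1  # weight by class support
    return total

print(round(weighted_f1([0, 0, 1, 1, 1, 2], [0, 1, 1, 1, 1, 2]), 4))  # 0.8175
```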
+ ## How to Use
+
+ ### With the `transformers` pipeline
+
+ ```python
+ from transformers import pipeline
+
+ classifier = pipeline(
+     "text-classification",
+     model="gr8monk3ys/paper-classifier-model",
+ )
+
+ abstract = (
+     "We introduce a new method for neural machine translation that uses "
+     "attention mechanisms to align source and target sentences, achieving "
+     "state-of-the-art results on WMT benchmarks."
+ )
+
+ result = classifier(abstract)
+ print(result)
+ # [{'label': 'cs.CL', 'score': 0.95}]
+ ```
+
+ ### With the included inference script
+
+ ```bash
+ python inference.py \
+     --model_path gr8monk3ys/paper-classifier-model \
+     --abstract "We propose a convolutional neural network for image recognition..."
+ ```
+
+ ### Training from scratch
+
+ ```bash
+ pip install -r requirements.txt
+
+ python train.py \
+     --num_train_epochs 5 \
+     --learning_rate 2e-5 \
+     --per_device_train_batch_size 16 \
+     --push_to_hub
+ ```
+
+ ## Limitations
+
+ - The model only covers a fixed set of 8 arXiv categories. Papers from other
+   fields will be forced into one of these buckets.
+ - Performance may degrade on abstracts that are unusually short, written in a
+   language other than English, or that span multiple subject areas.
+ - The model inherits any biases present in the DistilBERT base weights and in
+   the training dataset.
+
+ ## Citation
+
+ If you use this model in your research, please cite:
+
+ ```bibtex
+ @misc{scaturchio2025paperclassifier,
+   title  = {Academic Paper Classifier},
+   author = {Lorenzo Scaturchio},
+   year   = {2025},
+   url    = {https://huggingface.co/gr8monk3ys/paper-classifier-model}
+ }
+ ```
inference.py ADDED
@@ -0,0 +1,225 @@
+ #!/usr/bin/env python3
+ """
+ Inference script for the Academic Paper Classifier.
+
+ Loads a fine-tuned DistilBERT model and predicts the arXiv category for a
+ given paper abstract. Returns the predicted category along with per-class
+ confidence scores.
+
+ Usage examples:
+     # Use a local model directory
+     python inference.py --model_path ./model --abstract "We propose a novel ..."
+
+     # Use a HuggingFace Hub model
+     python inference.py --model_path gr8monk3ys/paper-classifier-model \
+         --abstract "We propose a novel ..."
+
+     # Interactive mode (reads from stdin)
+     python inference.py --model_path ./model
+
+ Author: Lorenzo Scaturchio (gr8monk3ys)
+ License: MIT
+ """
+
+ import argparse
+ import json
+ import logging
+ import sys
+
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ # ---------------------------------------------------------------------------
+ # Logging
+ # ---------------------------------------------------------------------------
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
+     handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ logger = logging.getLogger(__name__)
+
+
+ # ---------------------------------------------------------------------------
+ # Classifier wrapper
+ # ---------------------------------------------------------------------------
+ class PaperClassifier:
+     """Thin wrapper around a fine-tuned sequence-classification model.
+
+     Parameters
+     ----------
+     model_path : str
+         Path to a local model directory **or** a HuggingFace Hub model id.
+     device : str | None
+         Target device (``"cpu"``, ``"cuda"``, ``"mps"``). If *None* the best
+         available device is selected automatically.
+     """
+
+     def __init__(self, model_path: str, device: str | None = None) -> None:
+         if device is None:
+             if torch.cuda.is_available():
+                 device = "cuda"
+             elif torch.backends.mps.is_available():
+                 device = "mps"
+             else:
+                 device = "cpu"
+         self.device = torch.device(device)
+
+         logger.info("Loading tokenizer from: %s", model_path)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+         logger.info("Loading model from: %s", model_path)
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
+         self.model.to(self.device)
+
+         # Read label mapping stored in the model config
+         self.id2label: dict[int, str] = self.model.config.id2label
+         logger.info("Labels: %s", list(self.id2label.values()))
+
+     @torch.no_grad()
+     def predict(self, abstract: str, top_k: int | None = None) -> dict:
+         """Classify a single paper abstract.
+
+         Parameters
+         ----------
+         abstract : str
+             The paper abstract to classify.
+         top_k : int | None
+             If given, only the *top_k* categories (by confidence) are returned
+             in ``scores``. Pass *None* to return all categories.
+
+         Returns
+         -------
+         dict
+             ``{"label": str, "confidence": float, "scores": {label: prob}}``
+         """
+         self.model.eval()
+
+         inputs = self.tokenizer(
+             abstract,
+             return_tensors="pt",
+             truncation=True,
+             padding=True,
+             max_length=512,
+         ).to(self.device)
+
+         logits = self.model(**inputs).logits
+         probs = torch.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
+
+         sorted_indices = probs.argsort()[::-1]
+         if top_k is not None:
+             sorted_indices = sorted_indices[:top_k]
+
+         scores = {
+             self.id2label[int(idx)]: float(probs[idx]) for idx in sorted_indices
+         }
+
+         best_idx = int(probs.argmax())
+         return {
+             "label": self.id2label[best_idx],
+             "confidence": float(probs[best_idx]),
+             "scores": scores,
+         }
+
+
+ # ---------------------------------------------------------------------------
+ # CLI
+ # ---------------------------------------------------------------------------
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(
+         description="Classify an academic paper abstract into an arXiv category."
+     )
+     parser.add_argument(
+         "--model_path",
+         type=str,
+         default="./model",
+         help="Path to the fine-tuned model directory or HF Hub id (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--abstract",
+         type=str,
+         default=None,
+         help="Paper abstract text. If omitted, the script enters interactive mode.",
+     )
+     parser.add_argument(
+         "--top_k",
+         type=int,
+         default=None,
+         help="Only show the top-k predictions (default: show all).",
+     )
+     parser.add_argument(
+         "--device",
+         type=str,
+         default=None,
+         choices=["cpu", "cuda", "mps"],
+         help="Device to run inference on (default: auto-detect).",
+     )
+     parser.add_argument(
+         "--json",
+         action="store_true",
+         default=False,
+         dest="output_json",
+         help="Output raw JSON instead of human-readable text.",
+     )
+     return parser.parse_args()
+
+
+ def _print_result(result: dict, output_json: bool) -> None:
+     """Pretty-print or JSON-dump a prediction result."""
+     if output_json:
+         print(json.dumps(result, indent=2))
+         return
+
+     print(f"\n Predicted category : {result['label']}")
+     print(f" Confidence         : {result['confidence']:.4f}")
+     print(" ---------------------------------")
+     for label, score in result["scores"].items():
+         bar = "#" * int(score * 40)
+         print(f" {label:<10s} {score:6.4f} {bar}")
+     print()
+
+
+ def main() -> None:
+     args = parse_args()
+     classifier = PaperClassifier(model_path=args.model_path, device=args.device)
+
+     if args.abstract is not None:
+         result = classifier.predict(args.abstract, top_k=args.top_k)
+         _print_result(result, args.output_json)
+         return
+
+     # Interactive mode
+     print("Academic Paper Classifier - Interactive Mode")
+     print("Enter a paper abstract (or 'quit' to exit).")
+     print("For multi-line input, end with an empty line.\n")
+
+     while True:
+         try:
+             lines: list[str] = []
+             prompt = "abstract> " if sys.stdin.isatty() else ""
+             while True:
+                 line = input(prompt)
+                 if line.strip().lower() == "quit":
+                     logger.info("Exiting.")
+                     return
+                 if line == "" and lines:
+                     break
+                 lines.append(line)
+                 prompt = "... " if sys.stdin.isatty() else ""
+
+             abstract = " ".join(lines).strip()
+             if not abstract:
+                 continue
+
+             result = classifier.predict(abstract, top_k=args.top_k)
+             _print_result(result, args.output_json)
+
+         except (EOFError, KeyboardInterrupt):
+             print()
+             logger.info("Exiting.")
+             return
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers>=4.41.0
+ datasets>=2.16.0
+ torch>=2.1.0
+ scikit-learn>=1.3.0
+ accelerate>=0.25.0
+ evaluate>=0.4.0
+ huggingface_hub>=0.20.0
train.py ADDED
@@ -0,0 +1,453 @@
+ #!/usr/bin/env python3
+ """
+ Fine-tune DistilBERT for academic paper abstract classification.
+
+ This script downloads arXiv paper abstracts, preprocesses them, and fine-tunes
+ a DistilBERT model for multi-class sequence classification. Supports pushing
+ the trained model to the HuggingFace Hub.
+
+ Author: Lorenzo Scaturchio (gr8monk3ys)
+ License: MIT
+ """
+
+ import argparse
+ import logging
+ import sys
+ from pathlib import Path
+
+ import evaluate
+ import numpy as np
+ import torch
+ from datasets import ClassLabel, DatasetDict, load_dataset
+ from transformers import (
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+     EarlyStoppingCallback,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+
+ # ---------------------------------------------------------------------------
+ # Logging
+ # ---------------------------------------------------------------------------
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
+     handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ logger = logging.getLogger(__name__)
+
+ # ---------------------------------------------------------------------------
+ # Constants
+ # ---------------------------------------------------------------------------
+ MODEL_NAME = "distilbert-base-uncased"
+ DEFAULT_DATASET = "ccdv/arxiv-classification"
+ DEFAULT_OUTPUT_DIR = "./results"
+ DEFAULT_MODEL_DIR = "./model"
+
+ # Canonical label order so the id<->label mapping is deterministic.
+ LABEL_NAMES = [
+     "cs.AI",
+     "cs.CL",
+     "cs.CV",
+     "cs.LG",
+     "cs.NE",
+     "cs.RO",
+     "math.ST",
+     "stat.ML",
+ ]
+
+
+ # ---------------------------------------------------------------------------
+ # Helpers
+ # ---------------------------------------------------------------------------
+ def parse_args() -> argparse.Namespace:
+     """Parse command-line arguments for training hyperparameters."""
+     parser = argparse.ArgumentParser(
+         description="Fine-tune DistilBERT on arXiv paper classification."
+     )
+
+     # Data
+     parser.add_argument(
+         "--dataset_name",
+         type=str,
+         default=DEFAULT_DATASET,
+         help="HuggingFace dataset identifier (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--max_length",
+         type=int,
+         default=512,
+         help="Maximum token length for the tokenizer (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--max_train_samples",
+         type=int,
+         default=None,
+         help="Cap the number of training samples (useful for debugging).",
+     )
+     parser.add_argument(
+         "--max_eval_samples",
+         type=int,
+         default=None,
+         help="Cap the number of evaluation samples (useful for debugging).",
+     )
+
+     # Training
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default=DEFAULT_OUTPUT_DIR,
+         help="Directory for training checkpoints (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--model_dir",
+         type=str,
+         default=DEFAULT_MODEL_DIR,
+         help="Directory where the final model is saved (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--num_train_epochs",
+         type=int,
+         default=5,
+         help="Total number of training epochs (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--per_device_train_batch_size",
+         type=int,
+         default=16,
+         help="Batch size per device during training (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--per_device_eval_batch_size",
+         type=int,
+         default=32,
+         help="Batch size per device during evaluation (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=2e-5,
+         help="Peak learning rate (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--weight_decay",
+         type=float,
+         default=0.01,
+         help="Weight decay coefficient (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--warmup_ratio",
+         type=float,
+         default=0.1,
+         help="Fraction of total steps used for linear warmup (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--seed",
+         type=int,
+         default=42,
+         help="Random seed for reproducibility (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--early_stopping_patience",
+         type=int,
+         default=3,
+         help="Number of evaluations with no improvement before stopping (default: %(default)s).",
+     )
+     parser.add_argument(
+         "--fp16",
+         action="store_true",
+         default=False,
+         help="Use mixed-precision (FP16) training.",
+     )
+
+     # Hub
+     parser.add_argument(
+         "--push_to_hub",
+         action="store_true",
+         default=False,
+         help="Push the trained model to the HuggingFace Hub.",
+     )
+     parser.add_argument(
+         "--hub_model_id",
+         type=str,
+         default="gr8monk3ys/paper-classifier-model",
+         help="Repository id on the HuggingFace Hub (default: %(default)s).",
+     )
+
+     return parser.parse_args()
+
+
+ def build_label_mappings(label_names: list[str]) -> tuple[dict, dict]:
+     """Return (label2id, id2label) dicts for the given label names."""
+     label2id = {label: idx for idx, label in enumerate(label_names)}
+     id2label = {idx: label for idx, label in enumerate(label_names)}
+     return label2id, id2label
+
+
+ def load_and_prepare_dataset(
+     dataset_name: str,
+     label2id: dict[str, int],
+     max_train_samples: int | None = None,
+     max_eval_samples: int | None = None,
+ ) -> DatasetDict:
+     """Load the dataset and normalise the label column.
+
+     The function handles two common dataset layouts:
+       1. The dataset already has train / validation / test splits and a
+          numeric ``label`` column whose values match our ``label2id``.
+       2. The dataset has a string ``label`` column that needs mapping.
+
+     Returns a ``DatasetDict`` with ``train`` and ``validation`` splits.
+     """
+     logger.info("Loading dataset: %s", dataset_name)
+     raw = load_dataset(dataset_name, trust_remote_code=True)
+
+     # Determine the text and label column names --------------------------
+     sample_columns = list(next(iter(raw.values())).column_names)
+     text_col = None
+     for candidate in ("text", "abstract", "input", "sentence"):
+         if candidate in sample_columns:
+             text_col = candidate
+             break
+     if text_col is None:
+         # Fall back to the first string-typed column
+         text_col = sample_columns[0]
+     logger.info("Using text column: '%s'", text_col)
+
+     label_col = None
+     for candidate in ("label", "labels", "category", "class"):
+         if candidate in sample_columns:
+             label_col = candidate
+             break
+     if label_col is None:
+         label_col = sample_columns[-1]
+     logger.info("Using label column: '%s'", label_col)
+
+     # Rename columns so downstream code can rely on 'text' and 'label' ---
+     def _rename(example):
+         return {"text": str(example[text_col]), "label": example[label_col]}
+
+     raw = raw.map(_rename, remove_columns=sample_columns)
+
+     # If labels are strings, map them to ints using label2id -------------
+     sample_label = raw[list(raw.keys())[0]][0]["label"]
+     if isinstance(sample_label, str):
+         logger.info("Mapping string labels to integer ids.")
+
+         def _map_label(example):
+             lbl = example["label"]
+             if lbl in label2id:
+                 example["label"] = label2id[lbl]
+             else:
+                 example["label"] = -1  # will be filtered out
+             return example
+
+         raw = raw.map(_map_label)
+         raw = raw.filter(lambda ex: ex["label"] != -1)
+
+     # Ensure we have a ClassLabel feature --------------------------------
+     label_feature = ClassLabel(
+         num_classes=len(label2id), names=list(label2id.keys())
+     )
+     raw = raw.cast_column("label", label_feature)
+
+     # Build train / validation splits ------------------------------------
+     if "validation" not in raw and "test" in raw:
+         raw["validation"] = raw.pop("test")
+     elif "validation" not in raw:
+         split = raw["train"].train_test_split(
+             test_size=0.1, seed=42, stratify_by_column="label"
+         )
+         raw = DatasetDict({"train": split["train"], "validation": split["test"]})
+
+     # Subsample if requested ---------------------------------------------
+     if max_train_samples is not None:
+         raw["train"] = raw["train"].select(
+             range(min(max_train_samples, len(raw["train"])))
+         )
+     if max_eval_samples is not None:
+         raw["validation"] = raw["validation"].select(
+             range(min(max_eval_samples, len(raw["validation"])))
+         )
+
+     logger.info(
+         "Dataset sizes -> train: %d, validation: %d",
+         len(raw["train"]),
+         len(raw["validation"]),
+     )
+     return raw
+
+
+ def tokenize_dataset(
+     dataset: DatasetDict,
+     tokenizer: AutoTokenizer,
+     max_length: int,
+ ) -> DatasetDict:
+     """Tokenize the ``text`` column using the supplied tokenizer."""
+
+     def _tokenize(batch):
+         return tokenizer(
+             batch["text"],
+             padding="max_length",
+             truncation=True,
+             max_length=max_length,
+         )
+
+     logger.info("Tokenizing dataset (max_length=%d) ...", max_length)
+     tokenized = dataset.map(_tokenize, batched=True, desc="Tokenizing")
+     tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
+     return tokenized
+
+
+ def build_compute_metrics_fn():
+     """Return a ``compute_metrics`` callable for the HF Trainer.
+
+     Loads the ``accuracy``, ``f1``, ``precision`` and ``recall`` evaluate
+     metrics once at creation time to avoid repeated disk access.
+     """
+     acc_metric = evaluate.load("accuracy")
+     f1_metric = evaluate.load("f1")
+     prec_metric = evaluate.load("precision")
+     rec_metric = evaluate.load("recall")
+
+     def compute_metrics(eval_pred):
+         logits, labels = eval_pred
+         predictions = np.argmax(logits, axis=-1)
+         results = {}
+         results.update(acc_metric.compute(predictions=predictions, references=labels))
+         results.update(
+             f1_metric.compute(
+                 predictions=predictions, references=labels, average="weighted"
+             )
+         )
+         results.update(
+             prec_metric.compute(
+                 predictions=predictions, references=labels, average="weighted"
+             )
+         )
+         results.update(
+             rec_metric.compute(
+                 predictions=predictions, references=labels, average="weighted"
+             )
+         )
+         return results
+
+     return compute_metrics
+
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+ def main() -> None:
+     args = parse_args()
+
+     # Reproducibility
+     set_seed(args.seed)
+     logger.info("Seed set to %d", args.seed)
+
+     # Device info
+     device = "cuda" if torch.cuda.is_available() else (
+         "mps" if torch.backends.mps.is_available() else "cpu"
+     )
+     logger.info("Using device: %s", device)
+
+     # Label mappings
+     label2id, id2label = build_label_mappings(LABEL_NAMES)
+     num_labels = len(LABEL_NAMES)
+     logger.info("Number of labels: %d", num_labels)
+
+     # Dataset
+     dataset = load_and_prepare_dataset(
+         dataset_name=args.dataset_name,
+         label2id=label2id,
+         max_train_samples=args.max_train_samples,
+         max_eval_samples=args.max_eval_samples,
+     )
+
+     # Tokenizer
+     logger.info("Loading tokenizer: %s", MODEL_NAME)
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     tokenized_dataset = tokenize_dataset(dataset, tokenizer, args.max_length)
+
+     # Model
+     logger.info("Loading model: %s", MODEL_NAME)
+     model = AutoModelForSequenceClassification.from_pretrained(
+         MODEL_NAME,
+         num_labels=num_labels,
+         id2label=id2label,
+         label2id=label2id,
+     )
+
+     # Training arguments
+     training_args = TrainingArguments(
+         output_dir=args.output_dir,
+         num_train_epochs=args.num_train_epochs,
+         per_device_train_batch_size=args.per_device_train_batch_size,
+         per_device_eval_batch_size=args.per_device_eval_batch_size,
+         learning_rate=args.learning_rate,
+         weight_decay=args.weight_decay,
+         warmup_ratio=args.warmup_ratio,
+         lr_scheduler_type="linear",
+         eval_strategy="epoch",
+         save_strategy="epoch",
+         logging_strategy="steps",
+         logging_steps=50,
+         save_total_limit=2,
+         load_best_model_at_end=True,
+         metric_for_best_model="f1",
+         greater_is_better=True,
+         fp16=args.fp16 and torch.cuda.is_available(),
+         report_to="none",
+         seed=args.seed,
+         push_to_hub=False,  # we push manually after training
+     )
+
+     # Trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=tokenized_dataset["train"],
+         eval_dataset=tokenized_dataset["validation"],
+         tokenizer=tokenizer,
+         compute_metrics=build_compute_metrics_fn(),
+         callbacks=[
+             EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience),
+         ],
+     )
+
+     # Train
+     logger.info("Starting training ...")
+     train_result = trainer.train()
+     logger.info("Training complete.")
+
+     # Log final training metrics
+     metrics = train_result.metrics
+     trainer.log_metrics("train", metrics)
+     trainer.save_metrics("train", metrics)
+
+     # Evaluate
+     logger.info("Running final evaluation ...")
+     eval_metrics = trainer.evaluate()
+     trainer.log_metrics("eval", eval_metrics)
+     trainer.save_metrics("eval", eval_metrics)
+
+     # Save model + tokenizer
+     model_dir = Path(args.model_dir)
+     model_dir.mkdir(parents=True, exist_ok=True)
+     logger.info("Saving model to %s", model_dir)
+     trainer.save_model(str(model_dir))
+     tokenizer.save_pretrained(str(model_dir))
+
+     # Push to Hub
+     if args.push_to_hub:
+         logger.info("Pushing model to HuggingFace Hub: %s", args.hub_model_id)
+         try:
+             model.push_to_hub(args.hub_model_id)
+             tokenizer.push_to_hub(args.hub_model_id)
+             logger.info("Model pushed successfully.")
+         except Exception:
+             logger.exception("Failed to push model to Hub.")
+             sys.exit(1)
+
+     logger.info("All done.")
+
+
+ if __name__ == "__main__":
+     main()
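Because `LABEL_NAMES` fixes the enumeration order, `build_label_mappings` always produces the same id assignment, which is what keeps checkpoints, the stored `config.json`, and the README's label table in sync across runs. A quick standalone check of that contract, copying the function body verbatim:

```python
LABEL_NAMES = ["cs.AI", "cs.CL", "cs.CV", "cs.LG", "cs.NE", "cs.RO", "math.ST", "stat.ML"]

def build_label_mappings(label_names):
    """As in train.py: deterministic label<->id mappings from a fixed order."""
    label2id = {label: idx for idx, label in enumerate(label_names)}
    id2label = {idx: label for idx, label in enumerate(label_names)}
    return label2id, id2label

label2id, id2label = build_label_mappings(LABEL_NAMES)
# Round-trip: every label maps to an id that maps back to the same label.
assert all(id2label[label2id[name]] == name for name in LABEL_NAMES)
print(label2id["stat.ML"])  # 7
```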