Adive01 committed on
Commit
6aef09e
·
verified ·
1 Parent(s): ce2bcea

Upload mlplo/common.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlplo/common.py +301 -0
mlplo/common.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import re
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Any, Callable
9
+
10
+ import numpy as np
11
+ from datasets import Dataset
12
+ import torch
13
+ from transformers import AutoTokenizer
14
+
15
# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(__name__)

# ── Filesystem layout ──
# PROJECT_ROOT is the directory one level above the folder holding this file;
# every other location is derived from it.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
PACKAGE_ROOT = PROJECT_ROOT.joinpath("mlplo")
DATA_DIR = PACKAGE_ROOT.joinpath("data")
PROCESSED_DIR = DATA_DIR.joinpath("processed")
CACHE_DIR = DATA_DIR.joinpath("cache")
CHECKPOINT_DIR = PACKAGE_ROOT.joinpath("checkpoints")
ARTIFACT_DIR = PACKAGE_ROOT.joinpath("artifacts")

# ── Default model / dataset settings ──
DEFAULT_MODEL_NAME = "facebook/bart-large-xsum"
DEFAULT_DATASET_NAME = "xsum"
DEFAULT_TEXT_COLUMN = "document"
DEFAULT_SUMMARY_COLUMN = "summary"
DEFAULT_APP_FALLBACK_MODEL = "Adive01/bart-large-xsum-finetuned"
DEFAULT_INPUT_MAX_LENGTH = 1024
DEFAULT_TARGET_MAX_LENGTH = 96

# Platform flag: `datasets` relies on fork-based multiprocessing, which is
# unreliable on Windows, so callers can consult this before enabling it.
IS_WINDOWS = sys.platform == "win32"
35
+
36
+
37
+ # ── Directory helpers ──────────────────────────────────────────────────────────
38
+
39
def ensure_project_dirs() -> None:
    """Create the project's working directories if they are missing.

    Covers the data, processed, cache, checkpoint and artifact folders.
    Missing parent directories are created as well, and folders that already
    exist are left as they are.
    """
    required = (DATA_DIR, PROCESSED_DIR, CACHE_DIR, CHECKPOINT_DIR, ARTIFACT_DIR)
    for folder in required:
        folder.mkdir(parents=True, exist_ok=True)
42
+
43
+
44
+ # ── Text utilities ─────────────────────────────────────────────────────────────
45
+
46
# Boilerplate that commonly rides along when article text is pasted from a
# web page. Compiled once at import time instead of on every call. Each
# pattern is anchored with ``\b`` so it only fires on whole words: without the
# anchors "unrelated." was truncated to "un" and "spread more ..." to "sp".
_WEB_JUNK_PATTERNS = tuple(
    re.compile(pattern, flags=re.IGNORECASE)
    for pattern in (
        r"\bscroll down for video\b\.?",
        r"\badvertisement\b\.?",
        r"\bshare this article\b\.?",
        r"\bclick here to\s+\w+[^.]*\.",
        r"\bcookies? (?:policy|notice|settings)\b[^.]*\.",
        r"\bby [A-Z][a-z]+ [A-Z][a-z]+\s*\|",  # bylines "By John Smith |"
        r"\b\d{1,2}\s+(?:january|february|march|april|may|june|july|august"
        r"|september|october|november|december)\s+\d{4}\b",
        r"\bpublished:?\s*\d{1,2}[:/]\d{1,2}",
        r"\bupdated:?\s*\d{1,2}[:/]\d{1,2}",
        r"\bfollow us on (?:twitter|facebook|instagram|linkedin)\b[^.]*\.",
        r"\bsubscribe (?:to|for)\b[^.]*\.",
        r"\bsign up\b[^.]*\bnewsletter[^.]*\.",
        r"\[.*?\]",  # [image caption], [video], etc.
        r"\bread more\b:?[^.]*\.",
        r"\brelated\b:?[^.]*\.",
    )
)


def normalize_text(text: object) -> str:
    """Coerce *any* value to a clean, single-line string without web artifacts.

    The value is converted with ``str()``; non-breaking spaces, newlines and
    tabs become ordinary spaces; known web-page boilerplate (cookie banners,
    share prompts, ad labels, bylines, bracketed captions, "read more" /
    "related" teasers, ...) is removed; and runs of whitespace are collapsed.

    Parameters
    ----------
    text:
        Any object. ``None`` is treated as empty input.

    Returns
    -------
    str
        The cleaned text with no leading or trailing whitespace. An empty
        string is returned for ``None`` or when ``str(text)`` raises.
    """
    if text is None:
        return ""
    try:
        raw = str(text)
    except Exception:
        # Deliberately best-effort: an object with a broken ``__str__`` is
        # treated as empty input rather than aborting the caller.
        return ""

    # Normalise whitespace first so the patterns only ever see plain spaces.
    cleaned = raw.replace("\u00a0", " ")
    cleaned = re.sub(r"[\r\n\t]+", " ", cleaned)

    # Patterns are applied in declaration order; each hit becomes a space so
    # the neighbouring words do not fuse together.
    for junk in _WEB_JUNK_PATTERNS:
        cleaned = junk.sub(" ", cleaned)

    # Collapse the gaps left behind by the removals.
    return re.sub(r"\s+", " ", cleaned).strip()
87
+
88
+
89
+
90
def count_words(text: str) -> int:
    """Return the number of whitespace-separated tokens in *text*.

    Anything that is not a ``str`` counts as zero words.
    """
    return len(text.split()) if isinstance(text, str) else 0
94
+
95
+
96
def shorten_model_name(path_or_name: str) -> str:
    """Return a short display name for a model path or model identifier.

    An empty value yields ``"Unknown"``. A value that exists on disk, or that
    is an absolute path, is reduced to its final path component. Anything else
    is returned unchanged.
    """
    if not path_or_name:
        return "Unknown"
    candidate = Path(path_or_name)
    looks_like_path = candidate.exists() or candidate.is_absolute()
    return candidate.name if looks_like_path else path_or_name
103
+
104
+
105
+ # ── I/O helpers ────────────────────────────────────────────────────────────────
106
+
107
def write_json(path: str | Path, payload: dict[str, Any]) -> None:
    """Serialize *payload* as indented, ASCII-only JSON and write it to *path*.

    Missing parent directories are created. The file is written as UTF-8 and
    ends with a single trailing newline.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2, ensure_ascii=True)
    target.write_text(serialized + "\n", encoding="utf-8")
113
+
114
+
115
def write_jsonl(path: str | Path, rows: list[dict[str, Any]]) -> None:
    """Write *rows* to *path* as JSON Lines, one ASCII-only object per line.

    Missing parent directories are created. Every record is followed by a
    newline, so an empty *rows* produces an empty file.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    content = "".join(
        json.dumps(record, ensure_ascii=True) + "\n" for record in rows
    )
    target.write_text(content, encoding="utf-8")
122
+
123
+
124
def load_json(path: str | Path) -> dict[str, Any]:
    """Read the UTF-8 file at *path* and return its parsed JSON content."""
    raw = Path(path).read_text(encoding="utf-8")
    return json.loads(raw)
126
+
127
+
128
+ # ── Model / tokenizer helpers ──────────────────────────────────────────────────
129
+
130
def load_tokenizer(model_name: str):
    """Load the tokenizer for *model_name*, requesting the fast implementation.

    *model_name* is passed straight to ``AutoTokenizer.from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    return tokenizer
132
+
133
+
134
def build_preprocess_function(
    tokenizer,
    text_column: str,
    summary_column: str,
    max_input_length: int,
    max_target_length: int,
) -> Callable[[dict[str, list[str]]], dict[str, list[list[int]]]]:
    """Build a batched tokenization function for source/summary pairs.

    The returned callable takes a batch (a mapping of column name to a list of
    values), tokenizes ``batch[text_column]`` truncated to *max_input_length*,
    tokenizes ``batch[summary_column]`` as targets truncated to
    *max_target_length*, and stores the target ids under ``"labels"``.

    The callable raises ``KeyError`` when either column is absent from the
    batch.
    """

    def preprocess(batch: dict[str, list[str]]) -> dict[str, list[list[int]]]:
        # Fail early, with the available columns listed, before tokenizing.
        if text_column not in batch:
            raise KeyError(
                f"Text column '{text_column}' not found in batch. "
                f"Available columns: {list(batch.keys())}"
            )
        if summary_column not in batch:
            raise KeyError(
                f"Summary column '{summary_column}' not found in batch. "
                f"Available columns: {list(batch.keys())}"
            )

        sources = batch[text_column]
        encoded = tokenizer(sources, truncation=True, max_length=max_input_length)

        summaries = batch[summary_column]
        target_encoding = tokenizer(
            text_target=summaries,
            truncation=True,
            max_length=max_target_length,
        )

        # The model expects the target ids alongside the source encoding.
        encoded["labels"] = target_encoding["input_ids"]
        return encoded

    return preprocess
168
+
169
+
170
def resolve_mixed_precision() -> dict[str, bool]:
    """Choose mixed-precision flags for the current hardware.

    Returns a dict with ``"fp16"`` and ``"bf16"`` keys. Both are ``False``
    when CUDA is unavailable. With CUDA, bf16 is selected when supported and
    fp16 otherwise. A failing bf16 capability probe counts as "unsupported".
    """
    if not torch.cuda.is_available():
        return {"fp16": False, "bf16": False}

    try:
        use_bf16 = torch.cuda.is_bf16_supported()
    except (AttributeError, RuntimeError, AssertionError):
        # The probe is missing or failed; fall back to fp16.
        use_bf16 = False

    return {"fp16": not use_bf16, "bf16": use_bf16}
178
+
179
+
180
def default_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
182
+
183
+
184
def existing_default_checkpoint() -> str | None:
    """Return the newest valid checkpoint directory under ``CHECKPOINT_DIR``.

    Every directory found recursively beneath ``CHECKPOINT_DIR`` is examined.
    A directory qualifies when it contains ``model.safetensors`` or
    ``pytorch_model.bin``. Among the qualifying directories, the one with the
    latest modification time is returned as a string. ``None`` is returned
    when ``CHECKPOINT_DIR`` does not exist or nothing qualifies.
    """
    if not CHECKPOINT_DIR.exists():
        return None

    weight_files = ("model.safetensors", "pytorch_model.bin")
    checkpoints: list[Path] = [
        folder
        for folder in CHECKPOINT_DIR.rglob("*")
        if folder.is_dir()
        and any((folder / name).exists() for name in weight_files)
    ]
    if not checkpoints:
        return None

    newest = max(checkpoints, key=lambda folder: folder.stat().st_mtime)
    return str(newest)
204
+
205
+
206
def resolve_model_reference(path_or_name: str | None, fallback: str | None = None) -> str:
    """Turn a model path or name into the reference to load.

    A non-empty *path_or_name* that exists on disk is returned as a resolved
    absolute path; one that does not exist is returned unchanged. When
    *path_or_name* is empty, a non-empty *fallback* is returned instead.

    Raises
    ------
    ValueError
        If both *path_or_name* and *fallback* are empty.
    """
    if not path_or_name:
        if fallback:
            return fallback
        raise ValueError("A model path or model name is required.")

    local_path = Path(path_or_name)
    if local_path.exists():
        return str(local_path.resolve())
    return path_or_name
213
+
214
+
215
def validate_model_dir(path: str | Path) -> None:
    """Check that *path* exists and holds model weights.

    Weights are recognised as a ``model.safetensors`` or ``pytorch_model.bin``
    entry directly inside *path*.

    Raises
    ------
    FileNotFoundError
        If *path* does not exist, or if neither weight file is present.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Model path does not exist: {p}")

    weight_names = ("model.safetensors", "pytorch_model.bin")
    if any((p / name).exists() for name in weight_names):
        return

    raise FileNotFoundError(
        f"No model weights found in '{p}'. "
        "Expected 'model.safetensors' or 'pytorch_model.bin'."
    )
226
+
227
+
228
+ # ── Dataset helpers (single source of truth) ──────────────────────────────────
229
+
230
def maybe_limit_split(split: Dataset, limit: int | None) -> Dataset:
    """Return the first *limit* rows of *split*, or *split* itself.

    The split is returned untouched when *limit* is ``None`` or is at least
    as large as the split. Otherwise rows ``0 .. limit - 1`` are selected.
    """
    wants_everything = limit is None or limit >= len(split)
    return split if wants_everything else split.select(range(limit))
235
+
236
+
237
+ # ── Metrics (single source of truth) ──────────────────────────────────────────
238
+
239
def build_compute_metrics(tokenizer, *, include_bertscore: bool = False):
    """Return a ``compute_metrics`` callable suitable for ``Seq2SeqTrainer``.

    The callable decodes predictions and labels, then reports ROUGE-1,
    ROUGE-2 and ROUGE-L (with stemming) plus ``gen_len``, the mean number of
    non-padding tokens per prediction.

    Parameters
    ----------
    tokenizer:
        Used to decode predicted and label token IDs; its ``pad_token_id``
        replaces negative prediction ids and ``-100`` label ids before
        decoding.
    include_bertscore:
        When ``True``, also compute BERTScore F1 (requires ``bert-score``).
        Keep ``False`` during training — BERTScore downloads a ~400 MB model
        on first use and is 10-20× slower than ROUGE. Set ``True`` only for
        standalone evaluation passes (``mlplo.eval``).
    """
    import evaluate  # deferred: keeps module importable without evaluate installed

    rouge = evaluate.load("rouge")

    def compute_metrics(eval_prediction):
        predictions, labels = eval_prediction
        if isinstance(predictions, tuple):
            # Only the first element of a tuple output holds the token ids.
            predictions = predictions[0]

        pad_id = tokenizer.pad_token_id

        # Negative ids cannot be decoded, so map them to padding first.
        predictions = np.asarray(predictions)
        predictions = np.where(predictions < 0, pad_id, predictions)
        decoded_predictions = tokenizer.batch_decode(
            predictions, skip_special_tokens=True
        )

        # -100 marks ignored label positions; swap it for padding as well.
        labels = np.asarray(labels)
        labels = np.where(labels != -100, labels, pad_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        decoded_predictions = [text.strip() for text in decoded_predictions]
        decoded_labels = [text.strip() for text in decoded_labels]

        rouge_result = rouge.compute(
            predictions=decoded_predictions,
            references=decoded_labels,
            use_stemmer=True,
        )

        # Generated length = number of non-padding tokens in each prediction.
        token_counts = [
            int(np.count_nonzero(row != pad_id)) for row in predictions
        ]

        metrics: dict[str, float] = {
            key: round(rouge_result[key], 4)
            for key in ("rouge1", "rouge2", "rougeL")
        }
        metrics["gen_len"] = round(float(np.mean(token_counts)), 2)

        if not include_bertscore:
            return metrics

        from bert_score import score as bert_score_fn

        LOGGER.info("Computing BERTScore (downloads model on first use)…")
        # Blank strings are replaced by a placeholder before scoring.
        safe_preds = [p if p.strip() else "..." for p in decoded_predictions]
        safe_labels = [lb if lb.strip() else "..." for lb in decoded_labels]
        _precision, _recall, f1_scores = bert_score_fn(
            safe_preds, safe_labels, lang="en", verbose=False
        )
        metrics["bertscore_f1"] = round(float(f1_scores.mean().item()), 4)
        return metrics

    return compute_metrics