OliverPerrin committed on
Commit
f9d964d
·
1 Parent(s): f6d689c

Style: Fix linting errors and organize imports (ruff & mypy)

Browse files
scripts/download_data.py CHANGED
@@ -13,14 +13,12 @@ from urllib.request import urlopen
13
 
14
  from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
15
 
16
-
17
  PROJECT_ROOT = Path(__file__).resolve().parents[1]
18
  if str(PROJECT_ROOT) not in sys.path:
19
  sys.path.insert(0, str(PROJECT_ROOT))
20
 
21
  from src.utils.config import load_yaml
22
 
23
-
24
  DOWNLOAD_TIMEOUT = 60
25
  DEFAULT_SUMMARIZATION_DATASET = "gowrishankarp/newspaper-text-summarization-cnn-dailymail"
26
  DEFAULT_EMOTION_DATASET = "dair-ai/emotion"
@@ -33,16 +31,19 @@ def kaggle_download(dataset: str, output_dir: str) -> None:
33
  target = Path(output_dir)
34
  target.mkdir(parents=True, exist_ok=True)
35
  try:
36
- run([
37
- "kaggle",
38
- "datasets",
39
- "download",
40
- "-d",
41
- dataset,
42
- "-p",
43
- str(target),
44
- "--unzip",
45
- ], check=True)
 
 
 
46
  except CalledProcessError as error:
47
  raise RuntimeError(
48
  "Kaggle download failed. Verify that the Kaggle CLI is authenticated,"
@@ -71,8 +72,14 @@ def parse_args() -> argparse.Namespace:
71
  default="configs/data/datasets.yaml",
72
  help="Path to the dataset configuration YAML.",
73
  )
74
- parser.add_argument("--skip-kaggle", action="store_true", help="Skip downloading the Kaggle summarization dataset.")
75
- parser.add_argument("--skip-book", action="store_true", help="Skip downloading Gutenberg book texts.")
 
 
 
 
 
 
76
  return parser.parse_args()
77
 
78
 
@@ -92,11 +99,14 @@ def _write_jsonl(records: Iterable[dict[str, object]], destination: Path) -> Non
92
  handle.write(json.dumps(record, ensure_ascii=False) + "\n")
93
 
94
 
95
- def _emotion_records(dataset_split: Dataset, label_names: list[str] | None) -> Iterator[dict[str, object]]:
 
 
96
  for item in dataset_split:
97
  data = dict(item)
98
  text = data.get("text", "")
99
  label_value = data.get("label")
 
100
  def resolve_label(index: object) -> str:
101
  if isinstance(index, int) and label_names and 0 <= index < len(label_names):
102
  return label_names[index]
@@ -109,11 +119,14 @@ def _emotion_records(dataset_split: Dataset, label_names: list[str] | None) -> I
109
  yield {"text": text, "emotions": labels}
110
 
111
 
112
- def _topic_records(dataset_split: Dataset, label_names: list[str] | None) -> Iterator[dict[str, object]]:
 
 
113
  for item in dataset_split:
114
  data = dict(item)
115
  text = data.get("text") or data.get("content") or ""
116
  label_value = data.get("label")
 
117
  def resolve_topic(raw: object) -> str:
118
  if label_names:
119
  idx: int | None = None
@@ -142,12 +155,18 @@ def main() -> None:
142
  raw_paths = config.get("raw", {}) if isinstance(config, dict) else {}
143
  downloads_cfg = config.get("downloads", {}) if isinstance(config, dict) else {}
144
 
145
- summarization_cfg = downloads_cfg.get("summarization", {}) if isinstance(downloads_cfg, dict) else {}
 
 
146
  summarization_dataset = summarization_cfg.get("dataset", DEFAULT_SUMMARIZATION_DATASET)
147
- summarization_output = summarization_cfg.get("output", raw_paths.get("summarization", "data/raw/summarization"))
 
 
148
 
149
  if not args.skip_kaggle and summarization_dataset:
150
- print(f"Downloading summarization dataset '{summarization_dataset}' -> {summarization_output}")
 
 
151
  kaggle_download(summarization_dataset, summarization_output)
152
  else:
153
  print("Skipping Kaggle summarization download.")
@@ -174,7 +193,11 @@ def main() -> None:
174
  name = str(entry.get("name") or "gutenberg_text")
175
  url = str(entry.get("url") or DEFAULT_BOOK_URL)
176
  output_value = entry.get("output")
177
- destination = Path(output_value) if isinstance(output_value, str) and output_value else books_root / f"{name}.txt"
 
 
 
 
178
  destination.parent.mkdir(parents=True, exist_ok=True)
179
  print(f"Downloading Gutenberg text '{name}' from {url} -> {destination}")
180
  gutenberg_download(url, str(destination))
@@ -192,7 +215,9 @@ def main() -> None:
192
  if first_emotion_key is not None
193
  else None
194
  )
195
- emotion_label_names = emotion_label_feature.names if isinstance(emotion_label_feature, ClassLabel) else None
 
 
196
  for split_name, split in emotion_dataset.items():
197
  output_path = emotion_dir / f"{str(split_name)}.jsonl"
198
  _write_jsonl(_emotion_records(split, emotion_label_names), output_path)
@@ -209,7 +234,9 @@ def main() -> None:
209
  if first_topic_key is not None
210
  else None
211
  )
212
- topic_label_names = topic_label_feature.names if isinstance(topic_label_feature, ClassLabel) else None
 
 
213
  for split_name, split in topic_dataset.items():
214
  output_path = topic_dir / f"{str(split_name)}.jsonl"
215
  _write_jsonl(_topic_records(split, topic_label_names), output_path)
 
13
 
14
  from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
15
 
 
16
  PROJECT_ROOT = Path(__file__).resolve().parents[1]
17
  if str(PROJECT_ROOT) not in sys.path:
18
  sys.path.insert(0, str(PROJECT_ROOT))
19
 
20
  from src.utils.config import load_yaml
21
 
 
22
  DOWNLOAD_TIMEOUT = 60
23
  DEFAULT_SUMMARIZATION_DATASET = "gowrishankarp/newspaper-text-summarization-cnn-dailymail"
24
  DEFAULT_EMOTION_DATASET = "dair-ai/emotion"
 
31
  target = Path(output_dir)
32
  target.mkdir(parents=True, exist_ok=True)
33
  try:
34
+ run(
35
+ [
36
+ "kaggle",
37
+ "datasets",
38
+ "download",
39
+ "-d",
40
+ dataset,
41
+ "-p",
42
+ str(target),
43
+ "--unzip",
44
+ ],
45
+ check=True,
46
+ )
47
  except CalledProcessError as error:
48
  raise RuntimeError(
49
  "Kaggle download failed. Verify that the Kaggle CLI is authenticated,"
 
72
  default="configs/data/datasets.yaml",
73
  help="Path to the dataset configuration YAML.",
74
  )
75
+ parser.add_argument(
76
+ "--skip-kaggle",
77
+ action="store_true",
78
+ help="Skip downloading the Kaggle summarization dataset.",
79
+ )
80
+ parser.add_argument(
81
+ "--skip-book", action="store_true", help="Skip downloading Gutenberg book texts."
82
+ )
83
  return parser.parse_args()
84
 
85
 
 
99
  handle.write(json.dumps(record, ensure_ascii=False) + "\n")
100
 
101
 
102
+ def _emotion_records(
103
+ dataset_split: Dataset, label_names: list[str] | None
104
+ ) -> Iterator[dict[str, object]]:
105
  for item in dataset_split:
106
  data = dict(item)
107
  text = data.get("text", "")
108
  label_value = data.get("label")
109
+
110
  def resolve_label(index: object) -> str:
111
  if isinstance(index, int) and label_names and 0 <= index < len(label_names):
112
  return label_names[index]
 
119
  yield {"text": text, "emotions": labels}
120
 
121
 
122
+ def _topic_records(
123
+ dataset_split: Dataset, label_names: list[str] | None
124
+ ) -> Iterator[dict[str, object]]:
125
  for item in dataset_split:
126
  data = dict(item)
127
  text = data.get("text") or data.get("content") or ""
128
  label_value = data.get("label")
129
+
130
  def resolve_topic(raw: object) -> str:
131
  if label_names:
132
  idx: int | None = None
 
155
  raw_paths = config.get("raw", {}) if isinstance(config, dict) else {}
156
  downloads_cfg = config.get("downloads", {}) if isinstance(config, dict) else {}
157
 
158
+ summarization_cfg = (
159
+ downloads_cfg.get("summarization", {}) if isinstance(downloads_cfg, dict) else {}
160
+ )
161
  summarization_dataset = summarization_cfg.get("dataset", DEFAULT_SUMMARIZATION_DATASET)
162
+ summarization_output = summarization_cfg.get(
163
+ "output", raw_paths.get("summarization", "data/raw/summarization")
164
+ )
165
 
166
  if not args.skip_kaggle and summarization_dataset:
167
+ print(
168
+ f"Downloading summarization dataset '{summarization_dataset}' -> {summarization_output}"
169
+ )
170
  kaggle_download(summarization_dataset, summarization_output)
171
  else:
172
  print("Skipping Kaggle summarization download.")
 
193
  name = str(entry.get("name") or "gutenberg_text")
194
  url = str(entry.get("url") or DEFAULT_BOOK_URL)
195
  output_value = entry.get("output")
196
+ destination = (
197
+ Path(output_value)
198
+ if isinstance(output_value, str) and output_value
199
+ else books_root / f"{name}.txt"
200
+ )
201
  destination.parent.mkdir(parents=True, exist_ok=True)
202
  print(f"Downloading Gutenberg text '{name}' from {url} -> {destination}")
203
  gutenberg_download(url, str(destination))
 
215
  if first_emotion_key is not None
216
  else None
217
  )
218
+ emotion_label_names = (
219
+ emotion_label_feature.names if isinstance(emotion_label_feature, ClassLabel) else None
220
+ )
221
  for split_name, split in emotion_dataset.items():
222
  output_path = emotion_dir / f"{str(split_name)}.jsonl"
223
  _write_jsonl(_emotion_records(split, emotion_label_names), output_path)
 
234
  if first_topic_key is not None
235
  else None
236
  )
237
+ topic_label_names = (
238
+ topic_label_feature.names if isinstance(topic_label_feature, ClassLabel) else None
239
+ )
240
  for split_name, split in topic_dataset.items():
241
  output_path = topic_dir / f"{str(split_name)}.jsonl"
242
  _write_jsonl(_topic_records(split, topic_label_names), output_path)
scripts/eval_rouge.py CHANGED
@@ -3,181 +3,195 @@ from __future__ import annotations
3
 
4
  import argparse
5
  import json
 
6
  from collections import defaultdict
7
  from pathlib import Path
8
  from statistics import fmean
9
  from typing import Dict, Iterable, List, Sequence, Tuple
10
 
11
- import sys
12
-
13
  from rouge_score import rouge_scorer
14
  from tqdm import tqdm
15
 
16
  PROJECT_ROOT = Path(__file__).resolve().parent.parent
17
  if str(PROJECT_ROOT) not in sys.path:
18
- sys.path.insert(0, str(PROJECT_ROOT))
19
 
20
  from src.inference.factory import create_inference_pipeline
21
 
22
 
23
  def parse_args() -> argparse.Namespace:
24
- parser = argparse.ArgumentParser(description="Evaluate LexiMind summaries with ROUGE metrics.")
25
- parser.add_argument("data", type=Path, help="Path to JSONL file with source text and gold summaries.")
26
- parser.add_argument("checkpoint", type=Path, help="Path to the trained checkpoint (e.g., checkpoints/best.pt).")
27
- parser.add_argument("labels", type=Path, help="Path to label metadata (e.g., artifacts/labels.json).")
28
- parser.add_argument(
29
- "--tokenizer-dir",
30
- type=Path,
31
- default=Path("artifacts/hf_tokenizer"),
32
- help="Directory containing the saved tokenizer artifacts.",
33
- )
34
- parser.add_argument(
35
- "--model-config",
36
- type=Path,
37
- default=None,
38
- help="Optional YAML config describing the model architecture.",
39
- )
40
- parser.add_argument("--device", type=str, default="cpu", help="Device to run inference on (cpu or cuda).")
41
- parser.add_argument("--batch-size", type=int, default=8, help="Number of samples per inference batch.")
42
- parser.add_argument(
43
- "--max-samples",
44
- type=int,
45
- default=None,
46
- help="If provided, limit evaluation to the first N samples for quick smoke tests.",
47
- )
48
- parser.add_argument(
49
- "--max-length",
50
- type=int,
51
- default=128,
52
- help="Maximum length to pass into the summarization head during generation.",
53
- )
54
- parser.add_argument(
55
- "--metrics",
56
- type=str,
57
- nargs="+",
58
- default=("rouge1", "rouge2", "rougeL"),
59
- help="ROUGE metrics to compute.",
60
- )
61
- parser.add_argument(
62
- "--source-field",
63
- type=str,
64
- default="source",
65
- help="Field name containing the input document in the JSONL examples.",
66
- )
67
- parser.add_argument(
68
- "--target-field",
69
- type=str,
70
- default="summary",
71
- help="Field name containing the reference summary in the JSONL examples.",
72
- )
73
- parser.add_argument(
74
- "--no-stemmer",
75
- action="store_true",
76
- help="Disable Porter stemming inside the ROUGE scorer (defaults to enabled).",
77
- )
78
- parser.add_argument(
79
- "--output",
80
- type=Path,
81
- default=None,
82
- help="Optional path to save a JSON report with aggregate metrics and sample counts.",
83
- )
84
- return parser.parse_args()
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  def load_examples(
88
- path: Path,
89
- source_field: str,
90
- target_field: str,
91
- max_samples: int | None,
92
  ) -> List[Tuple[str, str]]:
93
- examples: List[Tuple[str, str]] = []
94
- with path.open("r", encoding="utf-8") as handle:
95
- for line in handle:
96
- line = line.strip()
97
- if not line:
98
- continue
99
- record = json.loads(line)
100
- try:
101
- source = str(record[source_field])
102
- target = str(record[target_field])
103
- except KeyError as exc: # pragma: no cover - invalid data surface at runtime
104
- raise KeyError(f"Missing field in record: {exc} (available keys: {list(record)})") from exc
105
- examples.append((source, target))
106
- if max_samples is not None and len(examples) >= max_samples:
107
- break
108
- if not examples:
109
- raise ValueError(f"No examples loaded from {path}")
110
- return examples
111
-
112
-
113
- def batched(items: Sequence[Tuple[str, str]], batch_size: int) -> Iterable[Sequence[Tuple[str, str]]]:
114
- for start in range(0, len(items), batch_size):
115
- yield items[start : start + batch_size]
 
 
 
 
116
 
117
 
118
  def aggregate_scores(raw_scores: Dict[str, Dict[str, List[float]]]) -> Dict[str, Dict[str, float]]:
119
- aggregated: Dict[str, Dict[str, float]] = {}
120
- for metric, components in raw_scores.items():
121
- aggregated[metric] = {
122
- component: (fmean(values) if values else 0.0) for component, values in components.items()
123
- }
124
- return aggregated
 
125
 
126
 
127
  def main() -> None:
128
- args = parse_args()
129
-
130
- pipeline, _ = create_inference_pipeline(
131
- checkpoint_path=args.checkpoint,
132
- labels_path=args.labels,
133
- tokenizer_dir=args.tokenizer_dir,
134
- model_config_path=args.model_config,
135
- device=args.device,
136
- summary_max_length=args.max_length,
137
- )
138
-
139
- examples = load_examples(args.data, args.source_field, args.target_field, args.max_samples)
140
- scorer = rouge_scorer.RougeScorer(list(args.metrics), use_stemmer=not args.no_stemmer)
141
-
142
- score_store: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
143
-
144
- for batch in tqdm(
145
- list(batched(examples, args.batch_size)),
146
- desc="Evaluating",
147
- total=(len(examples) + args.batch_size - 1) // args.batch_size,
148
- ):
149
- documents = [item[0] for item in batch]
150
- references = [item[1] for item in batch]
151
- predictions = pipeline.summarize(documents, max_length=args.max_length)
152
-
153
- for reference, prediction in zip(references, predictions):
154
- scores = scorer.score(reference, prediction)
155
- for metric_name, score in scores.items():
156
- score_store[metric_name]["precision"].append(score.precision)
157
- score_store[metric_name]["recall"].append(score.recall)
158
- score_store[metric_name]["fmeasure"].append(score.fmeasure)
159
-
160
- aggregated = aggregate_scores(score_store)
161
- report = {
162
- "num_examples": len(examples),
163
- "metrics": aggregated,
164
- "config": {
165
- "data": str(args.data),
166
- "checkpoint": str(args.checkpoint),
167
- "tokenizer_dir": str(args.tokenizer_dir),
168
- "metrics": list(args.metrics),
169
- "max_length": args.max_length,
170
- "batch_size": args.batch_size,
171
- "device": args.device,
172
- },
173
- }
174
-
175
- print(json.dumps(report, indent=2))
176
- if args.output:
177
- args.output.parent.mkdir(parents=True, exist_ok=True)
178
- with args.output.open("w", encoding="utf-8") as handle:
179
- json.dump(report, handle, ensure_ascii=False, indent=2)
180
 
181
 
182
  if __name__ == "__main__":
183
- main()
 
3
 
4
  import argparse
5
  import json
6
+ import sys
7
  from collections import defaultdict
8
  from pathlib import Path
9
  from statistics import fmean
10
  from typing import Dict, Iterable, List, Sequence, Tuple
11
 
 
 
12
  from rouge_score import rouge_scorer
13
  from tqdm import tqdm
14
 
15
  PROJECT_ROOT = Path(__file__).resolve().parent.parent
16
  if str(PROJECT_ROOT) not in sys.path:
17
+ sys.path.insert(0, str(PROJECT_ROOT))
18
 
19
  from src.inference.factory import create_inference_pipeline
20
 
21
 
22
  def parse_args() -> argparse.Namespace:
23
+ parser = argparse.ArgumentParser(description="Evaluate LexiMind summaries with ROUGE metrics.")
24
+ parser.add_argument(
25
+ "data", type=Path, help="Path to JSONL file with source text and gold summaries."
26
+ )
27
+ parser.add_argument(
28
+ "checkpoint", type=Path, help="Path to the trained checkpoint (e.g., checkpoints/best.pt)."
29
+ )
30
+ parser.add_argument(
31
+ "labels", type=Path, help="Path to label metadata (e.g., artifacts/labels.json)."
32
+ )
33
+ parser.add_argument(
34
+ "--tokenizer-dir",
35
+ type=Path,
36
+ default=Path("artifacts/hf_tokenizer"),
37
+ help="Directory containing the saved tokenizer artifacts.",
38
+ )
39
+ parser.add_argument(
40
+ "--model-config",
41
+ type=Path,
42
+ default=None,
43
+ help="Optional YAML config describing the model architecture.",
44
+ )
45
+ parser.add_argument(
46
+ "--device", type=str, default="cpu", help="Device to run inference on (cpu or cuda)."
47
+ )
48
+ parser.add_argument(
49
+ "--batch-size", type=int, default=8, help="Number of samples per inference batch."
50
+ )
51
+ parser.add_argument(
52
+ "--max-samples",
53
+ type=int,
54
+ default=None,
55
+ help="If provided, limit evaluation to the first N samples for quick smoke tests.",
56
+ )
57
+ parser.add_argument(
58
+ "--max-length",
59
+ type=int,
60
+ default=128,
61
+ help="Maximum length to pass into the summarization head during generation.",
62
+ )
63
+ parser.add_argument(
64
+ "--metrics",
65
+ type=str,
66
+ nargs="+",
67
+ default=("rouge1", "rouge2", "rougeL"),
68
+ help="ROUGE metrics to compute.",
69
+ )
70
+ parser.add_argument(
71
+ "--source-field",
72
+ type=str,
73
+ default="source",
74
+ help="Field name containing the input document in the JSONL examples.",
75
+ )
76
+ parser.add_argument(
77
+ "--target-field",
78
+ type=str,
79
+ default="summary",
80
+ help="Field name containing the reference summary in the JSONL examples.",
81
+ )
82
+ parser.add_argument(
83
+ "--no-stemmer",
84
+ action="store_true",
85
+ help="Disable Porter stemming inside the ROUGE scorer (defaults to enabled).",
86
+ )
87
+ parser.add_argument(
88
+ "--output",
89
+ type=Path,
90
+ default=None,
91
+ help="Optional path to save a JSON report with aggregate metrics and sample counts.",
92
+ )
93
+ return parser.parse_args()
94
 
95
 
96
  def load_examples(
97
+ path: Path,
98
+ source_field: str,
99
+ target_field: str,
100
+ max_samples: int | None,
101
  ) -> List[Tuple[str, str]]:
102
+ examples: List[Tuple[str, str]] = []
103
+ with path.open("r", encoding="utf-8") as handle:
104
+ for line in handle:
105
+ line = line.strip()
106
+ if not line:
107
+ continue
108
+ record = json.loads(line)
109
+ try:
110
+ source = str(record[source_field])
111
+ target = str(record[target_field])
112
+ except KeyError as exc: # pragma: no cover - invalid data surface at runtime
113
+ raise KeyError(
114
+ f"Missing field in record: {exc} (available keys: {list(record)})"
115
+ ) from exc
116
+ examples.append((source, target))
117
+ if max_samples is not None and len(examples) >= max_samples:
118
+ break
119
+ if not examples:
120
+ raise ValueError(f"No examples loaded from {path}")
121
+ return examples
122
+
123
+
124
def batched(
    items: Sequence[Tuple[str, str]], batch_size: int
) -> Iterable[Sequence[Tuple[str, str]]]:
    """Yield consecutive slices of ``items`` holding at most ``batch_size`` pairs.

    The final slice is shorter when ``len(items)`` is not a multiple of
    ``batch_size``; an empty ``items`` yields nothing.

    Args:
        items: Sequence of ``(source, target)`` pairs to split.
        batch_size: Maximum number of pairs per slice; must be at least 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Because this is a
            generator, the error is raised when iteration starts.
    """
    # A negative step would make range() empty and silently drop every item,
    # and a zero step fails with an unrelated-looking range() error.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]
129
 
130
 
131
  def aggregate_scores(raw_scores: Dict[str, Dict[str, List[float]]]) -> Dict[str, Dict[str, float]]:
132
+ aggregated: Dict[str, Dict[str, float]] = {}
133
+ for metric, components in raw_scores.items():
134
+ aggregated[metric] = {
135
+ component: (fmean(values) if values else 0.0)
136
+ for component, values in components.items()
137
+ }
138
+ return aggregated
139
 
140
 
141
  def main() -> None:
142
+ args = parse_args()
143
+
144
+ pipeline, _ = create_inference_pipeline(
145
+ checkpoint_path=args.checkpoint,
146
+ labels_path=args.labels,
147
+ tokenizer_dir=args.tokenizer_dir,
148
+ model_config_path=args.model_config,
149
+ device=args.device,
150
+ summary_max_length=args.max_length,
151
+ )
152
+
153
+ examples = load_examples(args.data, args.source_field, args.target_field, args.max_samples)
154
+ scorer = rouge_scorer.RougeScorer(list(args.metrics), use_stemmer=not args.no_stemmer)
155
+
156
+ score_store: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
157
+
158
+ for batch in tqdm(
159
+ list(batched(examples, args.batch_size)),
160
+ desc="Evaluating",
161
+ total=(len(examples) + args.batch_size - 1) // args.batch_size,
162
+ ):
163
+ documents = [item[0] for item in batch]
164
+ references = [item[1] for item in batch]
165
+ predictions = pipeline.summarize(documents, max_length=args.max_length)
166
+
167
+ for reference, prediction in zip(references, predictions):
168
+ scores = scorer.score(reference, prediction)
169
+ for metric_name, score in scores.items():
170
+ score_store[metric_name]["precision"].append(score.precision)
171
+ score_store[metric_name]["recall"].append(score.recall)
172
+ score_store[metric_name]["fmeasure"].append(score.fmeasure)
173
+
174
+ aggregated = aggregate_scores(score_store)
175
+ report = {
176
+ "num_examples": len(examples),
177
+ "metrics": aggregated,
178
+ "config": {
179
+ "data": str(args.data),
180
+ "checkpoint": str(args.checkpoint),
181
+ "tokenizer_dir": str(args.tokenizer_dir),
182
+ "metrics": list(args.metrics),
183
+ "max_length": args.max_length,
184
+ "batch_size": args.batch_size,
185
+ "device": args.device,
186
+ },
187
+ }
188
+
189
+ print(json.dumps(report, indent=2))
190
+ if args.output:
191
+ args.output.parent.mkdir(parents=True, exist_ok=True)
192
+ with args.output.open("w", encoding="utf-8") as handle:
193
+ json.dump(report, handle, ensure_ascii=False, indent=2)
194
 
195
 
196
  if __name__ == "__main__":
197
+ main()
scripts/preprocess_data.py CHANGED
@@ -25,8 +25,15 @@ def parse_args() -> argparse.Namespace:
25
  default="configs/data/datasets.yaml",
26
  help="Path to data configuration YAML.",
27
  )
28
- parser.add_argument("--val-ratio", type=float, default=0.1, help="Validation split size for topic dataset when no validation split is present.")
29
- parser.add_argument("--seed", type=int, default=17, help="Random seed for deterministic splitting.")
 
 
 
 
 
 
 
30
  return parser.parse_args()
31
 
32
 
@@ -73,7 +80,9 @@ def preprocess_books(
73
  for book_path in sorted(raw_dir.glob("*.txt")):
74
  text = book_path.read_text(encoding="utf-8").lstrip("\ufeff")
75
  normalized = text.replace("\r\n", "\n")
76
- paragraphs = [paragraph.strip() for paragraph in normalized.split("\n\n") if paragraph.strip()]
 
 
77
 
78
  records: list[Dict[str, object]] = []
79
  for paragraph_id, paragraph in enumerate(paragraphs):
@@ -130,7 +139,9 @@ def preprocess_summarization(raw_dir: Path, processed_dir: Path) -> None:
130
  output_path = processed_dir / f"{split}.jsonl"
131
  output_path.parent.mkdir(parents=True, exist_ok=True)
132
  print(f"Writing summarization split '{split}' to {output_path}")
133
- with source_path.open("r", encoding="utf-8", newline="") as source_handle, output_path.open("w", encoding="utf-8") as sink:
 
 
134
  reader = csv.DictReader(source_handle)
135
  for row in reader:
136
  article = row.get("article") or row.get("Article") or ""
@@ -167,7 +178,7 @@ def preprocess_emotion(raw_dir: Path, processed_dir: Path, cleaner: BasicTextCle
167
  assert source_path is not None
168
  path = source_path
169
 
170
- def iter_records() -> Iterator[Dict[str, object]]:
171
  if path.suffix == ".jsonl":
172
  for row in _read_jsonl(path):
173
  raw_text = str(row.get("text", ""))
@@ -186,12 +197,12 @@ def preprocess_emotion(raw_dir: Path, processed_dir: Path, cleaner: BasicTextCle
186
  delimiter = ";" if path.suffix == ".txt" else ","
187
  with path.open("r", encoding="utf-8", newline="") as handle:
188
  reader = csv.reader(handle, delimiter=delimiter)
189
- for row in reader:
190
- if not row:
191
  continue
192
- raw_text = str(row[0])
193
  text = cleaner.transform([raw_text])[0]
194
- raw_labels = row[1] if len(row) > 1 else ""
195
  labels = [label.strip() for label in raw_labels.split(",") if label.strip()]
196
  if not labels:
197
  labels = ["neutral"]
@@ -303,7 +314,9 @@ def main() -> None:
303
  topic_raw = Path(raw_cfg.get("topic", "data/raw/topic"))
304
 
305
  books_processed = Path(processed_cfg.get("books", "data/processed/books"))
306
- summarization_processed = Path(processed_cfg.get("summarization", "data/processed/summarization"))
 
 
307
  emotion_processed = Path(processed_cfg.get("emotion", "data/processed/emotion"))
308
  topic_processed = Path(processed_cfg.get("topic", "data/processed/topic"))
309
 
 
25
  default="configs/data/datasets.yaml",
26
  help="Path to data configuration YAML.",
27
  )
28
+ parser.add_argument(
29
+ "--val-ratio",
30
+ type=float,
31
+ default=0.1,
32
+ help="Validation split size for topic dataset when no validation split is present.",
33
+ )
34
+ parser.add_argument(
35
+ "--seed", type=int, default=17, help="Random seed for deterministic splitting."
36
+ )
37
  return parser.parse_args()
38
 
39
 
 
80
  for book_path in sorted(raw_dir.glob("*.txt")):
81
  text = book_path.read_text(encoding="utf-8").lstrip("\ufeff")
82
  normalized = text.replace("\r\n", "\n")
83
+ paragraphs = [
84
+ paragraph.strip() for paragraph in normalized.split("\n\n") if paragraph.strip()
85
+ ]
86
 
87
  records: list[Dict[str, object]] = []
88
  for paragraph_id, paragraph in enumerate(paragraphs):
 
139
  output_path = processed_dir / f"{split}.jsonl"
140
  output_path.parent.mkdir(parents=True, exist_ok=True)
141
  print(f"Writing summarization split '{split}' to {output_path}")
142
+ with source_path.open("r", encoding="utf-8", newline="") as source_handle, output_path.open(
143
+ "w", encoding="utf-8"
144
+ ) as sink:
145
  reader = csv.DictReader(source_handle)
146
  for row in reader:
147
  article = row.get("article") or row.get("Article") or ""
 
178
  assert source_path is not None
179
  path = source_path
180
 
181
+ def iter_records(path: Path = path) -> Iterator[Dict[str, object]]:
182
  if path.suffix == ".jsonl":
183
  for row in _read_jsonl(path):
184
  raw_text = str(row.get("text", ""))
 
197
  delimiter = ";" if path.suffix == ".txt" else ","
198
  with path.open("r", encoding="utf-8", newline="") as handle:
199
  reader = csv.reader(handle, delimiter=delimiter)
200
+ for csv_row in reader:
201
+ if not csv_row:
202
  continue
203
+ raw_text = str(csv_row[0])
204
  text = cleaner.transform([raw_text])[0]
205
+ raw_labels = csv_row[1] if len(csv_row) > 1 else ""
206
  labels = [label.strip() for label in raw_labels.split(",") if label.strip()]
207
  if not labels:
208
  labels = ["neutral"]
 
314
  topic_raw = Path(raw_cfg.get("topic", "data/raw/topic"))
315
 
316
  books_processed = Path(processed_cfg.get("books", "data/processed/books"))
317
+ summarization_processed = Path(
318
+ processed_cfg.get("summarization", "data/processed/summarization")
319
+ )
320
  emotion_processed = Path(processed_cfg.get("emotion", "data/processed/emotion"))
321
  topic_processed = Path(processed_cfg.get("topic", "data/processed/topic"))
322
 
src/api/dependencies.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  from fastapi import HTTPException, status
8
 
9
  from ..utils.logging import get_logger
 
10
  logger = get_logger(__name__)
11
 
12
  from ..inference.factory import create_inference_pipeline
 
7
  from fastapi import HTTPException, status
8
 
9
  from ..utils.logging import get_logger
10
+
11
  logger = get_logger(__name__)
12
 
13
  from ..inference.factory import create_inference_pipeline
src/api/routes.py CHANGED
@@ -11,7 +11,10 @@ router = APIRouter()
11
 
12
 
13
  @router.post("/summarize", response_model=SummaryResponse)
14
- def summarize(payload: SummaryRequest, pipeline: InferencePipeline = Depends(get_pipeline)) -> SummaryResponse:
 
 
 
15
  try:
16
  outputs = pipeline.batch_predict([payload.text])
17
  except Exception as exc: # noqa: BLE001 - surface inference error to client
 
11
 
12
 
13
  @router.post("/summarize", response_model=SummaryResponse)
14
+ def summarize(
15
+ payload: SummaryRequest,
16
+ pipeline: InferencePipeline = Depends(get_pipeline), # noqa: B008
17
+ ) -> SummaryResponse:
18
  try:
19
  outputs = pipeline.batch_predict([payload.text])
20
  except Exception as exc: # noqa: BLE001 - surface inference error to client
src/data/dataloader.py CHANGED
@@ -1,19 +1,32 @@
1
  """Task-aware DataLoader builders for the LexiMind multitask suite."""
2
  from __future__ import annotations
3
 
4
- from typing import Iterable, List
5
 
6
  import torch
7
  from torch.utils.data import DataLoader
8
 
9
- from .dataset import EmotionDataset, EmotionExample, SummarizationDataset, SummarizationExample, TopicDataset, TopicExample
 
 
 
 
 
 
 
10
  from .tokenization import Tokenizer
11
 
12
 
13
  class SummarizationCollator:
14
  """Prepare encoder-decoder batches for abstractive summarization."""
15
 
16
- def __init__(self, tokenizer: Tokenizer, *, max_source_length: int | None = None, max_target_length: int | None = None) -> None:
 
 
 
 
 
 
17
  self.tokenizer = tokenizer
18
  self.max_source_length = max_source_length
19
  self.max_target_length = max_target_length
@@ -29,17 +42,17 @@ class SummarizationCollator:
29
  # We want:
30
  # tgt_ids (decoder input): [BOS, A, B, EOS] (drop last PAD or EOS if full)
31
  # labels (target): [A, B, EOS, PAD] (drop first BOS)
32
-
33
  ids = target_enc["input_ids"]
34
  mask = target_enc["attention_mask"]
35
 
36
  # Slice to create shifted inputs/targets
37
  # tgt_ids: everything except the last token
38
  tgt_ids = ids[:, :-1]
39
-
40
  # labels: everything except the first token (BOS)
41
  labels = ids[:, 1:].clone()
42
-
43
  # Adjust mask for labels to ignore padding
44
  # The mask corresponds to the original ids. We slice it to match labels.
45
  labels_mask = mask[:, 1:]
@@ -56,7 +69,9 @@ class SummarizationCollator:
56
  class EmotionCollator:
57
  """Prepare batches for multi-label emotion classification."""
58
 
59
- def __init__(self, tokenizer: Tokenizer, dataset: EmotionDataset, *, max_length: int | None = None) -> None:
 
 
60
  self.tokenizer = tokenizer
61
  self.binarizer = dataset.binarizer
62
  self.max_length = max_length
@@ -76,7 +91,9 @@ class EmotionCollator:
76
  class TopicCollator:
77
  """Prepare batches for topic classification using the projection head."""
78
 
79
- def __init__(self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None) -> None:
 
 
80
  self.tokenizer = tokenizer
81
  self.encoder = dataset.encoder
82
  self.max_length = max_length
@@ -84,7 +101,9 @@ class TopicCollator:
84
  def __call__(self, batch: List[TopicExample]) -> dict[str, torch.Tensor]:
85
  texts = [example.text for example in batch]
86
  encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
87
- labels = torch.as_tensor(self.encoder.transform([example.topic for example in batch]), dtype=torch.long)
 
 
88
  return {
89
  "input_ids": encoded["input_ids"],
90
  "attention_mask": encoded["attention_mask"],
 
1
  """Task-aware DataLoader builders for the LexiMind multitask suite."""
2
  from __future__ import annotations
3
 
4
+ from typing import List
5
 
6
  import torch
7
  from torch.utils.data import DataLoader
8
 
9
+ from .dataset import (
10
+ EmotionDataset,
11
+ EmotionExample,
12
+ SummarizationDataset,
13
+ SummarizationExample,
14
+ TopicDataset,
15
+ TopicExample,
16
+ )
17
  from .tokenization import Tokenizer
18
 
19
 
20
  class SummarizationCollator:
21
  """Prepare encoder-decoder batches for abstractive summarization."""
22
 
23
+ def __init__(
24
+ self,
25
+ tokenizer: Tokenizer,
26
+ *,
27
+ max_source_length: int | None = None,
28
+ max_target_length: int | None = None,
29
+ ) -> None:
30
  self.tokenizer = tokenizer
31
  self.max_source_length = max_source_length
32
  self.max_target_length = max_target_length
 
42
  # We want:
43
  # tgt_ids (decoder input): [BOS, A, B, EOS] (drop last PAD or EOS if full)
44
  # labels (target): [A, B, EOS, PAD] (drop first BOS)
45
+
46
  ids = target_enc["input_ids"]
47
  mask = target_enc["attention_mask"]
48
 
49
  # Slice to create shifted inputs/targets
50
  # tgt_ids: everything except the last token
51
  tgt_ids = ids[:, :-1]
52
+
53
  # labels: everything except the first token (BOS)
54
  labels = ids[:, 1:].clone()
55
+
56
  # Adjust mask for labels to ignore padding
57
  # The mask corresponds to the original ids. We slice it to match labels.
58
  labels_mask = mask[:, 1:]
 
69
  class EmotionCollator:
70
  """Prepare batches for multi-label emotion classification."""
71
 
72
+ def __init__(
73
+ self, tokenizer: Tokenizer, dataset: EmotionDataset, *, max_length: int | None = None
74
+ ) -> None:
75
  self.tokenizer = tokenizer
76
  self.binarizer = dataset.binarizer
77
  self.max_length = max_length
 
91
  class TopicCollator:
92
  """Prepare batches for topic classification using the projection head."""
93
 
94
+ def __init__(
95
+ self, tokenizer: Tokenizer, dataset: TopicDataset, *, max_length: int | None = None
96
+ ) -> None:
97
  self.tokenizer = tokenizer
98
  self.encoder = dataset.encoder
99
  self.max_length = max_length
 
101
  def __call__(self, batch: List[TopicExample]) -> dict[str, torch.Tensor]:
102
  texts = [example.text for example in batch]
103
  encoded = self.tokenizer.batch_encode(texts, max_length=self.max_length)
104
+ labels = torch.as_tensor(
105
+ self.encoder.transform([example.topic for example in batch]), dtype=torch.long
106
+ )
107
  return {
108
  "input_ids": encoded["input_ids"],
109
  "attention_mask": encoded["attention_mask"],
src/data/preprocessing.py CHANGED
@@ -1,13 +1,11 @@
1
  """Text preprocessing utilities built around Hugging Face tokenizers."""
2
  from __future__ import annotations
3
 
4
- import re
5
  from dataclasses import dataclass, replace
6
  from typing import Iterable, List, Sequence
7
 
8
  import torch
9
  from sklearn.base import BaseEstimator, TransformerMixin
10
- from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
11
 
12
  from .tokenization import Tokenizer, TokenizerConfig
13
 
 
1
  """Text preprocessing utilities built around Hugging Face tokenizers."""
2
  from __future__ import annotations
3
 
 
4
  from dataclasses import dataclass, replace
5
  from typing import Iterable, List, Sequence
6
 
7
  import torch
8
  from sklearn.base import BaseEstimator, TransformerMixin
 
9
 
10
  from .tokenization import Tokenizer, TokenizerConfig
11
 
src/inference/factory.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
 
9
  from ..data.preprocessing import TextPreprocessor
10
  from ..data.tokenization import Tokenizer, TokenizerConfig
11
- from ..models.factory import ModelConfig, build_multitask_model, load_model_config
12
  from ..utils.io import load_state
13
  from ..utils.labels import LabelMetadata, load_label_metadata
14
  from .pipeline import InferenceConfig, InferencePipeline
@@ -38,7 +38,9 @@ def create_inference_pipeline(
38
  chosen_dir = Path(tokenizer_dir) if tokenizer_dir is not None else default_dir
39
  local_tokenizer_dir = chosen_dir
40
  if local_tokenizer_dir.exists():
41
- resolved_tokenizer_config = TokenizerConfig(pretrained_model_name=str(local_tokenizer_dir))
 
 
42
  else:
43
  raise ValueError(
44
  "No tokenizer configuration provided and default tokenizer directory "
@@ -46,11 +48,13 @@ def create_inference_pipeline(
46
  )
47
 
48
  tokenizer = Tokenizer(resolved_tokenizer_config)
49
-
50
  # Default to base config if not specified (checkpoint was trained with base config)
51
  if model_config_path is None:
52
- model_config_path = Path(__file__).resolve().parent.parent.parent / "configs" / "model" / "base.yaml"
53
-
 
 
54
  model_config = load_model_config(model_config_path)
55
  model = build_multitask_model(
56
  tokenizer,
@@ -59,7 +63,7 @@ def create_inference_pipeline(
59
  config=model_config,
60
  load_pretrained=False,
61
  )
62
-
63
  # Load checkpoint - weights will load separately since factory doesn't tie them
64
  load_state(model, str(checkpoint))
65
 
 
8
 
9
  from ..data.preprocessing import TextPreprocessor
10
  from ..data.tokenization import Tokenizer, TokenizerConfig
11
+ from ..models.factory import build_multitask_model, load_model_config
12
  from ..utils.io import load_state
13
  from ..utils.labels import LabelMetadata, load_label_metadata
14
  from .pipeline import InferenceConfig, InferencePipeline
 
38
  chosen_dir = Path(tokenizer_dir) if tokenizer_dir is not None else default_dir
39
  local_tokenizer_dir = chosen_dir
40
  if local_tokenizer_dir.exists():
41
+ resolved_tokenizer_config = TokenizerConfig(
42
+ pretrained_model_name=str(local_tokenizer_dir)
43
+ )
44
  else:
45
  raise ValueError(
46
  "No tokenizer configuration provided and default tokenizer directory "
 
48
  )
49
 
50
  tokenizer = Tokenizer(resolved_tokenizer_config)
51
+
52
  # Default to base config if not specified (checkpoint was trained with base config)
53
  if model_config_path is None:
54
+ model_config_path = (
55
+ Path(__file__).resolve().parent.parent.parent / "configs" / "model" / "base.yaml"
56
+ )
57
+
58
  model_config = load_model_config(model_config_path)
59
  model = build_multitask_model(
60
  tokenizer,
 
63
  config=model_config,
64
  load_pretrained=False,
65
  )
66
+
67
  # Load checkpoint - weights will load separately since factory doesn't tie them
68
  load_state(model, str(checkpoint))
69
 
src/models/__init__.py CHANGED
@@ -8,13 +8,13 @@ This package provides a from-scratch transformer implementation with:
8
  - MultiTaskModel: composable wrapper for encoder/decoder + task heads
9
  """
10
 
11
- from .encoder import TransformerEncoder, TransformerEncoderLayer
12
- from .decoder import TransformerDecoder, TransformerDecoderLayer, create_causal_mask
13
  from .attention import MultiHeadAttention
 
 
14
  from .feedforward import FeedForward
15
- from .positional_encoding import PositionalEncoding
16
- from .heads import ClassificationHead, TokenClassificationHead, LMHead, ProjectionHead
17
  from .multitask import MultiTaskModel
 
18
 
19
  __all__ = [
20
  "TransformerEncoder",
 
8
  - MultiTaskModel: composable wrapper for encoder/decoder + task heads
9
  """
10
 
 
 
11
  from .attention import MultiHeadAttention
12
+ from .decoder import TransformerDecoder, TransformerDecoderLayer, create_causal_mask
13
+ from .encoder import TransformerEncoder, TransformerEncoderLayer
14
  from .feedforward import FeedForward
15
+ from .heads import ClassificationHead, LMHead, ProjectionHead, TokenClassificationHead
 
16
  from .multitask import MultiTaskModel
17
+ from .positional_encoding import PositionalEncoding
18
 
19
  __all__ = [
20
  "TransformerEncoder",
src/models/heads.py CHANGED
@@ -9,7 +9,7 @@ Includes:
9
 
10
  Keep these heads minimal, well-tested, and easy to compose on top of encoder/decoder outputs.
11
  """
12
- from typing import Optional, Literal
13
 
14
  import torch
15
  import torch.nn as nn
@@ -96,8 +96,12 @@ class LMHead(nn.Module):
96
 
97
  if tie_embedding is not None:
98
  # Validate sizes
99
- assert tie_embedding.num_embeddings == vocab_size, "vocab size mismatch for weight tying"
100
- assert tie_embedding.embedding_dim == d_model, "embedding dim must match d_model for weight tying"
 
 
 
 
101
  # Tie weights: point the projection weight to the embedding weight Tensor
102
  # Remove the existing projection parameter in favor of the embedding weight
103
  # This keeps the same Parameter object, so updates affect both modules.
@@ -122,7 +126,13 @@ class ProjectionHead(nn.Module):
122
  dropout: dropout probability
123
  """
124
 
125
- def __init__(self, d_model: int, proj_dim: int = 128, hidden_dim: Optional[int] = None, dropout: float = 0.1):
 
 
 
 
 
 
126
  super().__init__()
127
  if hidden_dim is None:
128
  hidden_dim = max(d_model, proj_dim)
@@ -148,4 +158,4 @@ class ProjectionHead(nn.Module):
148
  elif orig_dim == 2:
149
  return self.net(x)
150
  else:
151
- raise ValueError("Input must be 2D or 3D tensor")
 
9
 
10
  Keep these heads minimal, well-tested, and easy to compose on top of encoder/decoder outputs.
11
  """
12
+ from typing import Literal, Optional
13
 
14
  import torch
15
  import torch.nn as nn
 
96
 
97
  if tie_embedding is not None:
98
  # Validate sizes
99
+ assert (
100
+ tie_embedding.num_embeddings == vocab_size
101
+ ), "vocab size mismatch for weight tying"
102
+ assert (
103
+ tie_embedding.embedding_dim == d_model
104
+ ), "embedding dim must match d_model for weight tying"
105
  # Tie weights: point the projection weight to the embedding weight Tensor
106
  # Remove the existing projection parameter in favor of the embedding weight
107
  # This keeps the same Parameter object, so updates affect both modules.
 
126
  dropout: dropout probability
127
  """
128
 
129
+ def __init__(
130
+ self,
131
+ d_model: int,
132
+ proj_dim: int = 128,
133
+ hidden_dim: Optional[int] = None,
134
+ dropout: float = 0.1,
135
+ ):
136
  super().__init__()
137
  if hidden_dim is None:
138
  hidden_dim = max(d_model, proj_dim)
 
158
  elif orig_dim == 2:
159
  return self.net(x)
160
  else:
161
+ raise ValueError("Input must be 2D or 3D tensor")
src/models/multitask.py CHANGED
@@ -14,16 +14,17 @@ Design goals:
14
  seq2seq tasks (encoder -> decoder -> LMHead)
15
  - Minimal dependencies on training loop; return logits and (optionally) loss
16
  """
17
- from typing import Optional, Dict, Any, Tuple
18
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
 
 
 
23
  # Import your components
24
  from .encoder import TransformerEncoder
25
- from .decoder import TransformerDecoder
26
- from .heads import ClassificationHead, TokenClassificationHead, LMHead
27
 
28
 
29
  class MultiTaskModel(nn.Module):
@@ -112,15 +113,21 @@ class MultiTaskModel(nn.Module):
112
  if "input_ids" in inputs:
113
  encoder_mask = None
114
  if "attention_mask" in inputs:
115
- encoder_mask = self._expand_attention_mask(inputs["attention_mask"], inputs["input_ids"].device)
 
 
116
  enc_out = self.encoder(inputs["input_ids"], mask=encoder_mask)
117
  elif "embeddings" in inputs:
118
  encoder_mask = inputs.get("attention_mask")
119
  if encoder_mask is not None:
120
- encoder_mask = self._expand_attention_mask(encoder_mask, inputs["embeddings"].device)
 
 
121
  enc_out = self.encoder(inputs["embeddings"], mask=encoder_mask)
122
  else:
123
- raise ValueError("inputs must contain 'input_ids' or 'embeddings' for encoder tasks")
 
 
124
  logits = head(enc_out)
125
 
126
  if return_loss:
@@ -152,7 +159,9 @@ class MultiTaskModel(nn.Module):
152
  elif "src_embeddings" in inputs:
153
  memory = self.encoder(inputs["src_embeddings"], mask=encoder_mask)
154
  else:
155
- raise ValueError("inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks")
 
 
156
 
157
  # If training / teacher forcing: expect tgt_ids (shifted by caller) or embeddings
158
  if "tgt_ids" in inputs:
@@ -162,7 +171,9 @@ class MultiTaskModel(nn.Module):
162
  else:
163
  # For generation time you may call decoder.greedy_decode separately.
164
  # Here we don't attempt to generate when labels not provided.
165
- raise ValueError("Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward")
 
 
166
 
167
  decoder_out = self.decoder(decoder_inputs, memory, memory_mask=src_mask)
168
 
@@ -209,13 +220,17 @@ class MultiTaskModel(nn.Module):
209
  if isinstance(head, TokenClassificationHead):
210
  # logits: (B, T, C), labels: (B, T)
211
  B, T, C = logits.shape
212
- loss = F.cross_entropy(logits.view(B * T, C), labels.view(B * T).long(), ignore_index=ignore_index)
 
 
213
  return loss
214
 
215
  if isinstance(head, LMHead):
216
  # logits: (B, T, V), labels: (B, T)
217
  B, T, V = logits.shape
218
- loss = F.cross_entropy(logits.view(B * T, V), labels.view(B * T).long(), ignore_index=ignore_index)
 
 
219
  return loss
220
 
221
  # Generic fall-back: try CrossEntropy on final dim
@@ -234,4 +249,4 @@ class MultiTaskModel(nn.Module):
234
  return bool_mask.unsqueeze(1) & bool_mask.unsqueeze(2)
235
  if bool_mask.dim() in (3, 4):
236
  return bool_mask
237
- raise ValueError("Attention mask must be 2D, 3D, or 4D tensor")
 
14
  seq2seq tasks (encoder -> decoder -> LMHead)
15
  - Minimal dependencies on training loop; return logits and (optionally) loss
16
  """
17
+ from typing import Any, Dict, Optional
18
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
 
23
+ from .decoder import TransformerDecoder
24
+
25
  # Import your components
26
  from .encoder import TransformerEncoder
27
+ from .heads import ClassificationHead, LMHead, TokenClassificationHead
 
28
 
29
 
30
  class MultiTaskModel(nn.Module):
 
113
  if "input_ids" in inputs:
114
  encoder_mask = None
115
  if "attention_mask" in inputs:
116
+ encoder_mask = self._expand_attention_mask(
117
+ inputs["attention_mask"], inputs["input_ids"].device
118
+ )
119
  enc_out = self.encoder(inputs["input_ids"], mask=encoder_mask)
120
  elif "embeddings" in inputs:
121
  encoder_mask = inputs.get("attention_mask")
122
  if encoder_mask is not None:
123
+ encoder_mask = self._expand_attention_mask(
124
+ encoder_mask, inputs["embeddings"].device
125
+ )
126
  enc_out = self.encoder(inputs["embeddings"], mask=encoder_mask)
127
  else:
128
+ raise ValueError(
129
+ "inputs must contain 'input_ids' or 'embeddings' for encoder tasks"
130
+ )
131
  logits = head(enc_out)
132
 
133
  if return_loss:
 
159
  elif "src_embeddings" in inputs:
160
  memory = self.encoder(inputs["src_embeddings"], mask=encoder_mask)
161
  else:
162
+ raise ValueError(
163
+ "inputs must contain 'src_ids' or 'src_embeddings' for seq2seq tasks"
164
+ )
165
 
166
  # If training / teacher forcing: expect tgt_ids (shifted by caller) or embeddings
167
  if "tgt_ids" in inputs:
 
171
  else:
172
  # For generation time you may call decoder.greedy_decode separately.
173
  # Here we don't attempt to generate when labels not provided.
174
+ raise ValueError(
175
+ "Seq2seq tasks require 'tgt_ids' or 'tgt_embeddings' for training forward"
176
+ )
177
 
178
  decoder_out = self.decoder(decoder_inputs, memory, memory_mask=src_mask)
179
 
 
220
  if isinstance(head, TokenClassificationHead):
221
  # logits: (B, T, C), labels: (B, T)
222
  B, T, C = logits.shape
223
+ loss = F.cross_entropy(
224
+ logits.view(B * T, C), labels.view(B * T).long(), ignore_index=ignore_index
225
+ )
226
  return loss
227
 
228
  if isinstance(head, LMHead):
229
  # logits: (B, T, V), labels: (B, T)
230
  B, T, V = logits.shape
231
+ loss = F.cross_entropy(
232
+ logits.view(B * T, V), labels.view(B * T).long(), ignore_index=ignore_index
233
+ )
234
  return loss
235
 
236
  # Generic fall-back: try CrossEntropy on final dim
 
249
  return bool_mask.unsqueeze(1) & bool_mask.unsqueeze(2)
250
  if bool_mask.dim() in (3, 4):
251
  return bool_mask
252
+ raise ValueError("Attention mask must be 2D, 3D, or 4D tensor")
src/models/positional_encoding.py CHANGED
@@ -7,31 +7,33 @@ Injects information about the position of tokens in a sequence, since
7
  self-attention has no inherent notion of token order.
8
  """
9
 
 
 
10
  import torch
11
  import torch.nn as nn
12
- import math
13
 
14
  class PositionalEncoding(nn.Module):
15
  """
16
  Implements the sinusoidal positional encoding from "Attention Is All You Need".
17
-
18
  Formula:
19
  PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
20
  PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
21
-
22
  Where:
23
  pos: position in sequence (0 to max_len-1)
24
  i: dimension index (0 to d_model/2)
25
-
26
  Args:
27
  d_model: Dimension of the model embeddings
28
  max_len: Maximum sequence length to pre-compute
29
  dropout: Dropout probability to apply after adding positional encoding
30
-
31
  Shape:
32
  Input: (batch, seq_len, d_model)
33
  Output: (batch, seq_len, d_model)
34
-
35
  Example:
36
  >>> pos_enc = PositionalEncoding(d_model=512, max_len=5000)
37
  >>> x = torch.randn(32, 100, 512) # (batch, seq, d_model)
@@ -39,7 +41,7 @@ class PositionalEncoding(nn.Module):
39
  >>> output.shape
40
  torch.Size([32, 100, 512])
41
  """
42
-
43
  def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
44
  super().__init__()
45
  self.dropout = nn.Dropout(p=dropout)
@@ -49,23 +51,20 @@ class PositionalEncoding(nn.Module):
49
  # Apply sin to even indices, cos to odd indices
50
  # Register as buffer (not a parameter, but part of state_dict)
51
  position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
52
- div_term = torch.exp(
53
- torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
54
- )
55
  pe = torch.zeros(max_len, d_model)
56
  pe[:, 0::2] = torch.sin(position * div_term) # Even indices
57
  pe[:, 1::2] = torch.cos(position * div_term) # Odd indices
58
  pe = pe.unsqueeze(0)
59
  self.register_buffer("pe", pe)
60
-
61
-
62
  def forward(self, x: torch.Tensor) -> torch.Tensor:
63
  """
64
  Add positional encoding to input embeddings.
65
-
66
  Args:
67
  x: Input embeddings (batch, seq_len, d_model)
68
-
69
  Returns:
70
  x with positional encoding added (batch, seq_len, d_model)
71
  """
@@ -76,4 +75,4 @@ class PositionalEncoding(nn.Module):
76
  x = x + self.pe[:, : x.size(1)].requires_grad_(False)
77
  # self.pe contains pre-computed encodings for all positions
78
  # just need to add the first seq_len positions to x
79
- return self.dropout(x)
 
7
  self-attention has no inherent notion of token order.
8
  """
9
 
10
+ import math
11
+
12
  import torch
13
  import torch.nn as nn
14
+
15
 
16
  class PositionalEncoding(nn.Module):
17
  """
18
  Implements the sinusoidal positional encoding from "Attention Is All You Need".
19
+
20
  Formula:
21
  PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
22
  PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
23
+
24
  Where:
25
  pos: position in sequence (0 to max_len-1)
26
  i: dimension index (0 to d_model/2)
27
+
28
  Args:
29
  d_model: Dimension of the model embeddings
30
  max_len: Maximum sequence length to pre-compute
31
  dropout: Dropout probability to apply after adding positional encoding
32
+
33
  Shape:
34
  Input: (batch, seq_len, d_model)
35
  Output: (batch, seq_len, d_model)
36
+
37
  Example:
38
  >>> pos_enc = PositionalEncoding(d_model=512, max_len=5000)
39
  >>> x = torch.randn(32, 100, 512) # (batch, seq, d_model)
 
41
  >>> output.shape
42
  torch.Size([32, 100, 512])
43
  """
44
+
45
  def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
46
  super().__init__()
47
  self.dropout = nn.Dropout(p=dropout)
 
51
  # Apply sin to even indices, cos to odd indices
52
  # Register as buffer (not a parameter, but part of state_dict)
53
  position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
54
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
 
 
55
  pe = torch.zeros(max_len, d_model)
56
  pe[:, 0::2] = torch.sin(position * div_term) # Even indices
57
  pe[:, 1::2] = torch.cos(position * div_term) # Odd indices
58
  pe = pe.unsqueeze(0)
59
  self.register_buffer("pe", pe)
60
+
 
61
  def forward(self, x: torch.Tensor) -> torch.Tensor:
62
  """
63
  Add positional encoding to input embeddings.
64
+
65
  Args:
66
  x: Input embeddings (batch, seq_len, d_model)
67
+
68
  Returns:
69
  x with positional encoding added (batch, seq_len, d_model)
70
  """
 
75
  x = x + self.pe[:, : x.size(1)].requires_grad_(False)
76
  # self.pe contains pre-computed encodings for all positions
77
  # just need to add the first seq_len positions to x
78
+ return self.dropout(x)
src/training/utils.py CHANGED
@@ -9,7 +9,6 @@ from typing import Optional
9
  import numpy as np
10
  import torch
11
 
12
-
13
  _seed_sequence: Optional[np.random.SeedSequence] = None
14
  _seed_lock = threading.Lock()
15
  _spawn_counter = 0
@@ -33,4 +32,3 @@ def set_seed(seed: int) -> np.random.Generator:
33
  _spawn_counter = 1
34
  _thread_local.rng = rng
35
  return rng
36
-
 
9
  import numpy as np
10
  import torch
11
 
 
12
  _seed_sequence: Optional[np.random.SeedSequence] = None
13
  _seed_lock = threading.Lock()
14
  _spawn_counter = 0
 
32
  _spawn_counter = 1
33
  _thread_local.rng = rng
34
  return rng
 
src/visualization/embeddings.py CHANGED
@@ -1,9 +1,9 @@
1
  """Embedding visualization helpers."""
2
 
3
  import matplotlib.pyplot as plt
 
4
  import pandas as pd
5
  import seaborn as sns
6
- import numpy as np
7
  from sklearn.manifold import TSNE
8
 
9
 
@@ -16,15 +16,17 @@ def plot_tsne(embeddings: np.ndarray, labels: list[str]) -> None:
16
  raise ValueError("number of samples in embeddings must equal length of labels")
17
  if embeddings.shape[1] < 2:
18
  raise ValueError("embeddings must have at least 2 features for t-SNE visualization")
19
-
20
  reducer = TSNE(n_components=2, init="pca", learning_rate="auto")
21
  projection = reducer.fit_transform(embeddings)
22
 
23
- df = pd.DataFrame({
24
- "x": projection[:, 0],
25
- "y": projection[:, 1],
26
- "label": labels,
27
- })
 
 
28
  plt.figure()
29
  sns.scatterplot(data=df, x="x", y="y", hue="label", palette="tab10", s=50)
30
  plt.legend(title="Labels", loc="best")
 
1
  """Embedding visualization helpers."""
2
 
3
  import matplotlib.pyplot as plt
4
+ import numpy as np
5
  import pandas as pd
6
  import seaborn as sns
 
7
  from sklearn.manifold import TSNE
8
 
9
 
 
16
  raise ValueError("number of samples in embeddings must equal length of labels")
17
  if embeddings.shape[1] < 2:
18
  raise ValueError("embeddings must have at least 2 features for t-SNE visualization")
19
+
20
  reducer = TSNE(n_components=2, init="pca", learning_rate="auto")
21
  projection = reducer.fit_transform(embeddings)
22
 
23
+ df = pd.DataFrame(
24
+ {
25
+ "x": projection[:, 0],
26
+ "y": projection[:, 1],
27
+ "label": labels,
28
+ }
29
+ )
30
  plt.figure()
31
  sns.scatterplot(data=df, x="x", y="y", hue="label", palette="tab10", s=50)
32
  plt.legend(title="Labels", loc="best")
tests/test_models/test_decoder_step.py CHANGED
@@ -1,6 +1,7 @@
1
- import torch
2
- import pytest
3
  from typing import Any, Dict, cast
 
 
 
4
  from src.models.decoder import TransformerDecoder
5
 
6
 
@@ -93,6 +94,5 @@ def test_step_cache_growth_and_shapes():
93
  for i in range(num_layers):
94
  assert f"mem_k_{i}" in cache and f"mem_v_{i}" in cache
95
  mem_k = cache[f"mem_k_{i}"]
96
- mem_v = cache[f"mem_v_{i}"]
97
  assert mem_k.shape[0] == batch_size
98
- assert mem_k.shape[2] == src_len # seq length of memory
 
 
 
1
  from typing import Any, Dict, cast
2
+
3
+ import torch
4
+
5
  from src.models.decoder import TransformerDecoder
6
 
7
 
 
94
  for i in range(num_layers):
95
  assert f"mem_k_{i}" in cache and f"mem_v_{i}" in cache
96
  mem_k = cache[f"mem_k_{i}"]
 
97
  assert mem_k.shape[0] == batch_size
98
+ assert mem_k.shape[2] == src_len # seq length of memory
tests/test_models/test_encoder.py CHANGED
@@ -1,6 +1,6 @@
1
- import math
2
- import torch
3
  import pytest
 
 
4
  from src.models.encoder import TransformerEncoder
5
 
6
 
@@ -173,4 +173,4 @@ def test_train_eval_determinism_and_dropout_effect():
173
 
174
 
175
  if __name__ == "__main__":
176
- pytest.main([__file__, "-q"])
 
 
 
1
  import pytest
2
+ import torch
3
+
4
  from src.models.encoder import TransformerEncoder
5
 
6
 
 
173
 
174
 
175
  if __name__ == "__main__":
176
+ pytest.main([__file__, "-q"])
tests/test_models/test_encoder_layer.py CHANGED
@@ -1,5 +1,6 @@
1
- import torch
2
  import pytest
 
 
3
  from src.models.encoder import TransformerEncoderLayer
4
 
5
 
@@ -83,4 +84,4 @@ def test_mask_broadcasting_accepts_3d_and_4d_mask():
83
 
84
  if __name__ == "__main__":
85
  # Run tests interactively if needed
86
- pytest.main([__file__, "-q"])
 
 
1
  import pytest
2
+ import torch
3
+
4
  from src.models.encoder import TransformerEncoderLayer
5
 
6
 
 
84
 
85
  if __name__ == "__main__":
86
  # Run tests interactively if needed
87
+ pytest.main([__file__, "-q"])
tests/test_models/test_feedforward.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- import pytest
3
  from src.models.feedforward import FeedForward
4
 
5
 
@@ -38,15 +38,15 @@ class TestFeedForward:
38
 
39
  # Parameter existence
40
  param_names = [name for name, _ in ffn.named_parameters()]
41
- assert any('linear1' in name for name in param_names)
42
- assert any('linear2' in name for name in param_names)
43
 
44
  # Parameter shapes
45
  shapes = {name: p.shape for name, p in ffn.named_parameters()}
46
- assert shapes.get('linear1.weight') == (d_ff, d_model)
47
- assert shapes.get('linear2.weight') == (d_model, d_ff)
48
- assert shapes.get('linear1.bias') == (d_ff,)
49
- assert shapes.get('linear2.bias') == (d_model,)
50
 
51
  # ensure gradients flow
52
  x = torch.randn(3, 5, d_model)
@@ -54,4 +54,4 @@ class TestFeedForward:
54
  loss = out.sum()
55
  loss.backward()
56
  for _, p in ffn.named_parameters():
57
- assert p.grad is not None
 
1
  import torch
2
+
3
  from src.models.feedforward import FeedForward
4
 
5
 
 
38
 
39
  # Parameter existence
40
  param_names = [name for name, _ in ffn.named_parameters()]
41
+ assert any("linear1" in name for name in param_names)
42
+ assert any("linear2" in name for name in param_names)
43
 
44
  # Parameter shapes
45
  shapes = {name: p.shape for name, p in ffn.named_parameters()}
46
+ assert shapes.get("linear1.weight") == (d_ff, d_model)
47
+ assert shapes.get("linear2.weight") == (d_model, d_ff)
48
+ assert shapes.get("linear1.bias") == (d_ff,)
49
+ assert shapes.get("linear2.bias") == (d_model,)
50
 
51
  # ensure gradients flow
52
  x = torch.randn(3, 5, d_model)
 
54
  loss = out.sum()
55
  loss.backward()
56
  for _, p in ffn.named_parameters():
57
+ assert p.grad is not None
tests/test_models/test_heads.py CHANGED
@@ -1,11 +1,11 @@
1
  import torch
2
- import pytest
3
  import torch.nn as nn
 
4
  from src.models.heads import (
5
  ClassificationHead,
6
- TokenClassificationHead,
7
  LMHead,
8
  ProjectionHead,
 
9
  )
10
 
11
 
@@ -101,4 +101,4 @@ def test_projection_head_2d_and_3d_behavior_and_grad():
101
  loss = out3.sum()
102
  loss.backward()
103
  grads = [p.grad for p in head.parameters() if p.requires_grad]
104
- assert any(g is not None for g in grads)
 
1
  import torch
 
2
  import torch.nn as nn
3
+
4
  from src.models.heads import (
5
  ClassificationHead,
 
6
  LMHead,
7
  ProjectionHead,
8
+ TokenClassificationHead,
9
  )
10
 
11
 
 
101
  loss = out3.sum()
102
  loss.backward()
103
  grads = [p.grad for p in head.parameters() if p.requires_grad]
104
+ assert any(g is not None for g in grads)
tests/test_models/test_multitask.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
- import pytest
3
- from src.models.encoder import TransformerEncoder
4
  from src.models.decoder import TransformerDecoder
 
5
  from src.models.heads import ClassificationHead, LMHead, TokenClassificationHead
6
  from src.models.multitask import MultiTaskModel
7
 
@@ -17,8 +17,16 @@ def test_multitask_encoder_classification_forward_and_loss():
17
  seq_len = 8
18
  num_labels = 5
19
 
20
- enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
21
- num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=seq_len, pad_token_id=0)
 
 
 
 
 
 
 
 
22
 
23
  mt = MultiTaskModel(encoder=enc)
24
  head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.0)
@@ -30,7 +38,9 @@ def test_multitask_encoder_classification_forward_and_loss():
30
  logits = mt.forward("sentiment", {"input_ids": input_ids})
31
  assert logits.shape == (batch_size, num_labels)
32
 
33
- loss, logits2 = mt.forward("sentiment", {"input_ids": input_ids, "labels": labels}, return_loss=True)
 
 
34
  assert loss.item() >= 0
35
  # grads
36
  loss.backward()
@@ -49,10 +59,26 @@ def test_multitask_seq2seq_lm_forward_and_loss():
49
  src_len = 7
50
  tgt_len = 6
51
 
52
- enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
53
- num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=src_len, pad_token_id=0)
54
- dec = TransformerDecoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
55
- num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=tgt_len, pad_token_id=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  mt = MultiTaskModel(encoder=enc, decoder=dec)
57
  lm_head = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)
58
  mt.add_head("summarize", lm_head)
@@ -65,7 +91,9 @@ def test_multitask_seq2seq_lm_forward_and_loss():
65
  logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
66
  assert logits.shape == (batch_size, tgt_len, vocab_size)
67
 
68
- loss, logits2 = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids, "labels": labels}, return_loss=True)
 
 
69
  assert loss.item() >= 0
70
  loss.backward()
71
  grads = [p.grad for p in mt.parameters() if p.requires_grad]
@@ -83,8 +111,16 @@ def test_token_classification_forward_and_loss():
83
  seq_len = 5
84
  num_labels = 7
85
 
86
- enc = TransformerEncoder(vocab_size=vocab_size, d_model=d_model, num_layers=num_layers,
87
- num_heads=num_heads, d_ff=d_ff, dropout=0.0, max_len=seq_len, pad_token_id=0)
 
 
 
 
 
 
 
 
88
  mt = MultiTaskModel(encoder=enc)
89
  head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
90
  mt.add_head("ner", head)
@@ -99,4 +135,4 @@ def test_token_classification_forward_and_loss():
99
  assert loss.item() >= 0
100
  loss.backward()
101
  grads = [p.grad for p in mt.parameters() if p.requires_grad]
102
- assert any(g is not None for g in grads)
 
1
  import torch
2
+
 
3
  from src.models.decoder import TransformerDecoder
4
+ from src.models.encoder import TransformerEncoder
5
  from src.models.heads import ClassificationHead, LMHead, TokenClassificationHead
6
  from src.models.multitask import MultiTaskModel
7
 
 
17
  seq_len = 8
18
  num_labels = 5
19
 
20
+ enc = TransformerEncoder(
21
+ vocab_size=vocab_size,
22
+ d_model=d_model,
23
+ num_layers=num_layers,
24
+ num_heads=num_heads,
25
+ d_ff=d_ff,
26
+ dropout=0.0,
27
+ max_len=seq_len,
28
+ pad_token_id=0,
29
+ )
30
 
31
  mt = MultiTaskModel(encoder=enc)
32
  head = ClassificationHead(d_model=d_model, num_labels=num_labels, pooler="mean", dropout=0.0)
 
38
  logits = mt.forward("sentiment", {"input_ids": input_ids})
39
  assert logits.shape == (batch_size, num_labels)
40
 
41
+ loss, logits2 = mt.forward(
42
+ "sentiment", {"input_ids": input_ids, "labels": labels}, return_loss=True
43
+ )
44
  assert loss.item() >= 0
45
  # grads
46
  loss.backward()
 
59
  src_len = 7
60
  tgt_len = 6
61
 
62
+ enc = TransformerEncoder(
63
+ vocab_size=vocab_size,
64
+ d_model=d_model,
65
+ num_layers=num_layers,
66
+ num_heads=num_heads,
67
+ d_ff=d_ff,
68
+ dropout=0.0,
69
+ max_len=src_len,
70
+ pad_token_id=0,
71
+ )
72
+ dec = TransformerDecoder(
73
+ vocab_size=vocab_size,
74
+ d_model=d_model,
75
+ num_layers=num_layers,
76
+ num_heads=num_heads,
77
+ d_ff=d_ff,
78
+ dropout=0.0,
79
+ max_len=tgt_len,
80
+ pad_token_id=0,
81
+ )
82
  mt = MultiTaskModel(encoder=enc, decoder=dec)
83
  lm_head = LMHead(d_model=d_model, vocab_size=vocab_size, tie_embedding=None)
84
  mt.add_head("summarize", lm_head)
 
91
  logits = mt.forward("summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids})
92
  assert logits.shape == (batch_size, tgt_len, vocab_size)
93
 
94
+ loss, logits2 = mt.forward(
95
+ "summarize", {"src_ids": src_ids, "tgt_ids": tgt_ids, "labels": labels}, return_loss=True
96
+ )
97
  assert loss.item() >= 0
98
  loss.backward()
99
  grads = [p.grad for p in mt.parameters() if p.requires_grad]
 
111
  seq_len = 5
112
  num_labels = 7
113
 
114
+ enc = TransformerEncoder(
115
+ vocab_size=vocab_size,
116
+ d_model=d_model,
117
+ num_layers=num_layers,
118
+ num_heads=num_heads,
119
+ d_ff=d_ff,
120
+ dropout=0.0,
121
+ max_len=seq_len,
122
+ pad_token_id=0,
123
+ )
124
  mt = MultiTaskModel(encoder=enc)
125
  head = TokenClassificationHead(d_model=d_model, num_labels=num_labels, dropout=0.0)
126
  mt.add_head("ner", head)
 
135
  assert loss.item() >= 0
136
  loss.backward()
137
  grads = [p.grad for p in mt.parameters() if p.requires_grad]
138
+ assert any(g is not None for g in grads)