OliverPerrin committed on
Commit
29f2de2
·
1 Parent(s): f9d964d

Style: Apply ruff formatting

Browse files
scripts/demo_gradio.py CHANGED
@@ -2,6 +2,7 @@
2
  Minimal Gradio demo for the LexiMind multitask model.
3
  Shows raw model outputs without any post-processing tricks.
4
  """
 
5
  from __future__ import annotations
6
 
7
  import json
 
2
  Minimal Gradio demo for the LexiMind multitask model.
3
  Shows raw model outputs without any post-processing tricks.
4
  """
5
+
6
  from __future__ import annotations
7
 
8
  import json
scripts/download_data.py CHANGED
@@ -1,4 +1,5 @@
1
  """Download datasets used by LexiMind."""
 
2
  from __future__ import annotations
3
 
4
  import argparse
 
1
  """Download datasets used by LexiMind."""
2
+
3
  from __future__ import annotations
4
 
5
  import argparse
scripts/eval_rouge.py CHANGED
@@ -1,4 +1,5 @@
1
  """Utility script to evaluate LexiMind summaries with ROUGE."""
 
2
  from __future__ import annotations
3
 
4
  import argparse
 
1
  """Utility script to evaluate LexiMind summaries with ROUGE."""
2
+
3
  from __future__ import annotations
4
 
5
  import argparse
scripts/evaluate.py CHANGED
@@ -2,6 +2,7 @@
2
  Evaluate the multitask model on processed validation/test splits.
3
  This is used for getting definitive scores on my test set after training is complete.
4
  """
 
5
  from __future__ import annotations
6
 
7
  import argparse
 
2
  Evaluate the multitask model on processed validation/test splits.
3
  This is used for getting definitive scores on my test set after training is complete.
4
  """
5
+
6
  from __future__ import annotations
7
 
8
  import argparse
scripts/export_model.py CHANGED
@@ -1,4 +1,5 @@
1
  """Rebuild and export the trained multitask model for downstream use."""
 
2
  from __future__ import annotations
3
 
4
  import argparse
@@ -14,11 +15,27 @@ from src.utils.labels import load_label_metadata
14
 
15
  def parse_args() -> argparse.Namespace:
16
  parser = argparse.ArgumentParser(description="Export LexiMind model weights")
17
- parser.add_argument("--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint.")
18
- parser.add_argument("--output", default="outputs/model.pt", help="Output path for the exported state dict.")
19
- parser.add_argument("--labels", default="artifacts/labels.json", help="Label metadata JSON produced after training.")
20
- parser.add_argument("--model-config", default="configs/model/base.yaml", help="Model architecture configuration.")
21
- parser.add_argument("--data-config", default="configs/data/datasets.yaml", help="Data configuration (for tokenizer settings).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  return parser.parse_args()
23
 
24
 
 
1
  """Rebuild and export the trained multitask model for downstream use."""
2
+
3
  from __future__ import annotations
4
 
5
  import argparse
 
15
 
16
  def parse_args() -> argparse.Namespace:
17
  parser = argparse.ArgumentParser(description="Export LexiMind model weights")
18
+ parser.add_argument(
19
+ "--checkpoint", default="checkpoints/best.pt", help="Path to the trained checkpoint."
20
+ )
21
+ parser.add_argument(
22
+ "--output", default="outputs/model.pt", help="Output path for the exported state dict."
23
+ )
24
+ parser.add_argument(
25
+ "--labels",
26
+ default="artifacts/labels.json",
27
+ help="Label metadata JSON produced after training.",
28
+ )
29
+ parser.add_argument(
30
+ "--model-config",
31
+ default="configs/model/base.yaml",
32
+ help="Model architecture configuration.",
33
+ )
34
+ parser.add_argument(
35
+ "--data-config",
36
+ default="configs/data/datasets.yaml",
37
+ help="Data configuration (for tokenizer settings).",
38
+ )
39
  return parser.parse_args()
40
 
41
 
scripts/inference.py CHANGED
@@ -1,4 +1,5 @@
1
  """Run inference with the multitask model."""
 
2
  from __future__ import annotations
3
 
4
  import argparse
 
1
  """Run inference with the multitask model."""
2
+
3
  from __future__ import annotations
4
 
5
  import argparse
scripts/preprocess_data.py CHANGED
@@ -1,4 +1,5 @@
1
  """Preprocess raw datasets into JSONL splits for LexiMind training."""
 
2
  from __future__ import annotations
3
 
4
  import argparse
@@ -139,9 +140,10 @@ def preprocess_summarization(raw_dir: Path, processed_dir: Path) -> None:
139
  output_path = processed_dir / f"{split}.jsonl"
140
  output_path.parent.mkdir(parents=True, exist_ok=True)
141
  print(f"Writing summarization split '{split}' to {output_path}")
142
- with source_path.open("r", encoding="utf-8", newline="") as source_handle, output_path.open(
143
- "w", encoding="utf-8"
144
- ) as sink:
 
145
  reader = csv.DictReader(source_handle)
146
  for row in reader:
147
  article = row.get("article") or row.get("Article") or ""
 
1
  """Preprocess raw datasets into JSONL splits for LexiMind training."""
2
+
3
  from __future__ import annotations
4
 
5
  import argparse
 
140
  output_path = processed_dir / f"{split}.jsonl"
141
  output_path.parent.mkdir(parents=True, exist_ok=True)
142
  print(f"Writing summarization split '{split}' to {output_path}")
143
+ with (
144
+ source_path.open("r", encoding="utf-8", newline="") as source_handle,
145
+ output_path.open("w", encoding="utf-8") as sink,
146
+ ):
147
  reader = csv.DictReader(source_handle)
148
  for row in reader:
149
  article = row.get("article") or row.get("Article") or ""
scripts/train.py CHANGED
@@ -1,4 +1,5 @@
1
  """End-to-end training entrypoint for the LexiMind multitask model."""
 
2
  from __future__ import annotations
3
 
4
  import json
 
1
  """End-to-end training entrypoint for the LexiMind multitask model."""
2
+
3
  from __future__ import annotations
4
 
5
  import json
src/api/app.py CHANGED
@@ -1,4 +1,5 @@
1
  """FastAPI application entrypoint."""
 
2
  from fastapi import FastAPI
3
 
4
  from .routes import router
 
1
  """FastAPI application entrypoint."""
2
+
3
  from fastapi import FastAPI
4
 
5
  from .routes import router
src/api/dependencies.py CHANGED
@@ -1,4 +1,5 @@
1
  """Dependency providers for the FastAPI application."""
 
2
  from __future__ import annotations
3
 
4
  from functools import lru_cache
 
1
  """Dependency providers for the FastAPI application."""
2
+
3
  from __future__ import annotations
4
 
5
  from functools import lru_cache
src/api/routes.py CHANGED
@@ -1,4 +1,5 @@
1
  """API routes."""
 
2
  from typing import cast
3
 
4
  from fastapi import APIRouter, Depends, HTTPException, status
 
1
  """API routes."""
2
+
3
  from typing import cast
4
 
5
  from fastapi import APIRouter, Depends, HTTPException, status
src/api/schemas.py CHANGED
@@ -1,4 +1,5 @@
1
  """API schemas."""
 
2
  from pydantic import BaseModel
3
 
4
 
 
1
  """API schemas."""
2
+
3
  from pydantic import BaseModel
4
 
5
 
src/data/dataloader.py CHANGED
@@ -1,4 +1,5 @@
1
  """Task-aware DataLoader builders for the LexiMind multitask suite."""
 
2
  from __future__ import annotations
3
 
4
  from typing import List
 
1
  """Task-aware DataLoader builders for the LexiMind multitask suite."""
2
+
3
  from __future__ import annotations
4
 
5
  from typing import List
src/data/dataset.py CHANGED
@@ -1,4 +1,5 @@
1
  """Dataset definitions for the LexiMind multitask training pipeline."""
 
2
  from __future__ import annotations
3
 
4
  import json
@@ -179,7 +180,9 @@ def _load_jsonl_generic(
179
  if first_non_ws == "[":
180
  payloads = _safe_json_load(handle, data_path)
181
  if not isinstance(payloads, list):
182
- raise ValueError(f"Expected a JSON array in '{data_path}' but found {type(payloads).__name__}")
 
 
183
  for idx, payload in enumerate(payloads):
184
  if not isinstance(payload, dict):
185
  raise ValueError(
 
1
  """Dataset definitions for the LexiMind multitask training pipeline."""
2
+
3
  from __future__ import annotations
4
 
5
  import json
 
180
  if first_non_ws == "[":
181
  payloads = _safe_json_load(handle, data_path)
182
  if not isinstance(payloads, list):
183
+ raise ValueError(
184
+ f"Expected a JSON array in '{data_path}' but found {type(payloads).__name__}"
185
+ )
186
  for idx, payload in enumerate(payloads):
187
  if not isinstance(payload, dict):
188
  raise ValueError(
src/data/preprocessing.py CHANGED
@@ -1,4 +1,5 @@
1
  """Text preprocessing utilities built around Hugging Face tokenizers."""
 
2
  from __future__ import annotations
3
 
4
  from dataclasses import dataclass, replace
 
1
  """Text preprocessing utilities built around Hugging Face tokenizers."""
2
+
3
  from __future__ import annotations
4
 
5
  from dataclasses import dataclass, replace
src/data/tokenization.py CHANGED
@@ -1,4 +1,5 @@
1
  """Tokenizer wrapper around HuggingFace models used across LexiMind."""
 
2
  from __future__ import annotations
3
 
4
  from dataclasses import dataclass
@@ -23,13 +24,19 @@ class Tokenizer:
23
  def __init__(self, config: TokenizerConfig | None = None) -> None:
24
  cfg = config or TokenizerConfig()
25
  self.config = cfg
26
- self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(cfg.pretrained_model_name)
 
 
27
  self._pad_token_id = self._resolve_id(self._tokenizer.pad_token_id)
28
  self._bos_token_id = self._resolve_id(
29
- self._tokenizer.bos_token_id if self._tokenizer.bos_token_id is not None else self._tokenizer.cls_token_id
 
 
30
  )
31
  self._eos_token_id = self._resolve_id(
32
- self._tokenizer.eos_token_id if self._tokenizer.eos_token_id is not None else self._tokenizer.sep_token_id
 
 
33
  )
34
 
35
  @property
@@ -84,7 +91,9 @@ class Tokenizer:
84
  )
85
  return cast(List[List[int]], encoded["input_ids"])
86
 
87
- def batch_encode(self, texts: Sequence[str], *, max_length: int | None = None) -> dict[str, torch.Tensor]:
 
 
88
  normalized = [text.lower() if self.config.lower else text for text in texts]
89
  encoded = self._tokenizer(
90
  normalized,
 
1
  """Tokenizer wrapper around HuggingFace models used across LexiMind."""
2
+
3
  from __future__ import annotations
4
 
5
  from dataclasses import dataclass
 
24
  def __init__(self, config: TokenizerConfig | None = None) -> None:
25
  cfg = config or TokenizerConfig()
26
  self.config = cfg
27
+ self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
28
+ cfg.pretrained_model_name
29
+ )
30
  self._pad_token_id = self._resolve_id(self._tokenizer.pad_token_id)
31
  self._bos_token_id = self._resolve_id(
32
+ self._tokenizer.bos_token_id
33
+ if self._tokenizer.bos_token_id is not None
34
+ else self._tokenizer.cls_token_id
35
  )
36
  self._eos_token_id = self._resolve_id(
37
+ self._tokenizer.eos_token_id
38
+ if self._tokenizer.eos_token_id is not None
39
+ else self._tokenizer.sep_token_id
40
  )
41
 
42
  @property
 
91
  )
92
  return cast(List[List[int]], encoded["input_ids"])
93
 
94
+ def batch_encode(
95
+ self, texts: Sequence[str], *, max_length: int | None = None
96
+ ) -> dict[str, torch.Tensor]:
97
  normalized = [text.lower() if self.config.lower else text for text in texts]
98
  encoded = self._tokenizer(
99
  normalized,
src/inference/__init__.py CHANGED
@@ -4,9 +4,9 @@ from .factory import create_inference_pipeline
4
  from .pipeline import EmotionPrediction, InferenceConfig, InferencePipeline, TopicPrediction
5
 
6
  __all__ = [
7
- "InferencePipeline",
8
- "InferenceConfig",
9
- "EmotionPrediction",
10
- "TopicPrediction",
11
- "create_inference_pipeline",
12
  ]
 
4
  from .pipeline import EmotionPrediction, InferenceConfig, InferencePipeline, TopicPrediction
5
 
6
  __all__ = [
7
+ "InferencePipeline",
8
+ "InferenceConfig",
9
+ "EmotionPrediction",
10
+ "TopicPrediction",
11
+ "create_inference_pipeline",
12
  ]
src/inference/factory.py CHANGED
@@ -1,4 +1,5 @@
1
  """Helpers to assemble an inference pipeline from saved artifacts."""
 
2
  from __future__ import annotations
3
 
4
  from pathlib import Path
 
1
  """Helpers to assemble an inference pipeline from saved artifacts."""
2
+
3
  from __future__ import annotations
4
 
5
  from pathlib import Path
src/inference/pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  """Inference helpers for multitask LexiMind models."""
 
2
  from __future__ import annotations
3
 
4
  from dataclasses import dataclass, fields, replace
 
1
  """Inference helpers for multitask LexiMind models."""
2
+
3
  from __future__ import annotations
4
 
5
  from dataclasses import dataclass, fields, replace
src/inference/postprocessing.py CHANGED
@@ -1,4 +1,5 @@
1
  """Output cleaning helpers."""
 
2
  from typing import List
3
 
4
 
 
1
  """Output cleaning helpers."""
2
+
3
  from typing import List
4
 
5
 
src/models/decoder.py CHANGED
@@ -12,6 +12,7 @@ Conventions:
12
  - This decoder uses Pre-LN (RMSNorm before each sublayer).
13
  - RMSNorm is just simpler than LayerNorm and more computationally efficient, it's become the modern convention. These reasons are why I used it here.
14
  """
 
15
  import math
16
  from typing import Dict, List, Optional, Tuple, Union
17
 
 
12
  - This decoder uses Pre-LN (RMSNorm before each sublayer).
13
  - RMSNorm is just simpler than LayerNorm and more computationally efficient, it's become the modern convention. These reasons are why I used it here.
14
  """
15
+
16
  import math
17
  from typing import Dict, List, Optional, Tuple, Union
18
 
src/models/factory.py CHANGED
@@ -1,4 +1,5 @@
1
  """Factory helpers to assemble multitask models for inference/training."""
 
2
  from __future__ import annotations
3
 
4
  from dataclasses import dataclass
 
1
  """Factory helpers to assemble multitask models for inference/training."""
2
+
3
  from __future__ import annotations
4
 
5
  from dataclasses import dataclass
src/models/heads.py CHANGED
@@ -9,6 +9,7 @@ Includes:
9
 
10
  Keep these heads minimal, well-tested, and easy to compose on top of encoder/decoder outputs.
11
  """
 
12
  from typing import Literal, Optional
13
 
14
  import torch
 
9
 
10
  Keep these heads minimal, well-tested, and easy to compose on top of encoder/decoder outputs.
11
  """
12
+
13
  from typing import Literal, Optional
14
 
15
  import torch
src/models/multitask.py CHANGED
@@ -14,6 +14,7 @@ Design goals:
14
  seq2seq tasks (encoder -> decoder -> LMHead)
15
  - Minimal dependencies on training loop; return logits and (optionally) loss
16
  """
 
17
  from typing import Any, Dict, Optional
18
 
19
  import torch
 
14
  seq2seq tasks (encoder -> decoder -> LMHead)
15
  - Minimal dependencies on training loop; return logits and (optionally) loss
16
  """
17
+
18
  from typing import Any, Dict, Optional
19
 
20
  import torch
src/training/metrics.py CHANGED
@@ -1,4 +1,5 @@
1
  """Metric helpers used during training and evaluation."""
 
2
  from __future__ import annotations
3
 
4
  from typing import Any, Dict, List, Sequence
 
1
  """Metric helpers used during training and evaluation."""
2
+
3
  from __future__ import annotations
4
 
5
  from typing import Any, Dict, List, Sequence
src/training/trainer.py CHANGED
@@ -1,4 +1,5 @@
1
  """Multi-task trainer coordinating summarization, emotion, and topic heads."""
 
2
  from __future__ import annotations
3
 
4
  import shutil
@@ -330,9 +331,9 @@ class Trainer:
330
  """Generate and print sample summaries to monitor quality during training."""
331
  self.model.eval()
332
  samples_generated = 0
333
- print(f"\n{'='*80}")
334
  print(f"[Validation Generation - Epoch {epoch}]")
335
- print(f"{'='*80}")
336
 
337
  with torch.no_grad():
338
  for batch in val_loader:
@@ -400,7 +401,7 @@ class Trainer:
400
 
401
  samples_generated += 1
402
 
403
- print(f"{'='*80}\n")
404
  self.model.train()
405
 
406
  def _print_epoch_progress(
 
1
  """Multi-task trainer coordinating summarization, emotion, and topic heads."""
2
+
3
  from __future__ import annotations
4
 
5
  import shutil
 
331
  """Generate and print sample summaries to monitor quality during training."""
332
  self.model.eval()
333
  samples_generated = 0
334
+ print(f"\n{'=' * 80}")
335
  print(f"[Validation Generation - Epoch {epoch}]")
336
+ print(f"{'=' * 80}")
337
 
338
  with torch.no_grad():
339
  for batch in val_loader:
 
401
 
402
  samples_generated += 1
403
 
404
+ print(f"{'=' * 80}\n")
405
  self.model.train()
406
 
407
  def _print_epoch_progress(
src/utils/config.py CHANGED
@@ -1,4 +1,5 @@
1
  """YAML config loader."""
 
2
  from dataclasses import dataclass
3
  from pathlib import Path
4
  from typing import Any, Dict
 
1
  """YAML config loader."""
2
+
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
  from typing import Any, Dict
src/utils/io.py CHANGED
@@ -1,4 +1,5 @@
1
  """Checkpoint IO helpers."""
 
2
  from pathlib import Path
3
 
4
  import torch
@@ -12,4 +13,4 @@ def save_state(model: torch.nn.Module, path: str) -> None:
12
 
13
  def load_state(model: torch.nn.Module, path: str) -> None:
14
  state = torch.load(path, map_location="cpu", weights_only=True)
15
- model.load_state_dict(state)
 
1
  """Checkpoint IO helpers."""
2
+
3
  from pathlib import Path
4
 
5
  import torch
 
13
 
14
  def load_state(model: torch.nn.Module, path: str) -> None:
15
  state = torch.load(path, map_location="cpu", weights_only=True)
16
+ model.load_state_dict(state)
src/utils/labels.py CHANGED
@@ -1,4 +1,5 @@
1
  """Label metadata helpers for multitask inference."""
 
2
  from __future__ import annotations
3
 
4
  import json
 
1
  """Label metadata helpers for multitask inference."""
2
+
3
  from __future__ import annotations
4
 
5
  import json
src/utils/logging.py CHANGED
@@ -1,4 +1,5 @@
1
  """Logging setup."""
 
2
  import logging
3
 
4
 
 
1
  """Logging setup."""
2
+
3
  import logging
4
 
5
 
src/utils/random.py CHANGED
@@ -1,4 +1,5 @@
1
  """Randomness helpers."""
 
2
  import random
3
 
4
  import numpy as np
 
1
  """Randomness helpers."""
2
+
3
  import random
4
 
5
  import numpy as np
src/visualization/attention.py CHANGED
@@ -1,4 +1,5 @@
1
  """Attention plotting utilities."""
 
2
  from typing import Sequence
3
 
4
  import matplotlib.pyplot as plt
 
1
  """Attention plotting utilities."""
2
+
3
  from typing import Sequence
4
 
5
  import matplotlib.pyplot as plt
src/visualization/metrics.py CHANGED
@@ -1,4 +1,5 @@
1
  """Metric plotting helpers."""
 
2
  import matplotlib.pyplot as plt
3
 
4
 
 
1
  """Metric plotting helpers."""
2
+
3
  import matplotlib.pyplot as plt
4
 
5
 
tests/test_api/test_routes.py CHANGED
@@ -1,4 +1,5 @@
1
  """API integration tests for the inference endpoint."""
 
2
  from __future__ import annotations
3
 
4
  from fastapi.testclient import TestClient
@@ -31,4 +32,4 @@ def test_summarize_route_returns_pipeline_outputs() -> None:
31
  assert payload["topic"] == "news"
32
  assert payload["topic_confidence"] == 0.8
33
  finally:
34
- app.dependency_overrides.clear()
 
1
  """API integration tests for the inference endpoint."""
2
+
3
  from __future__ import annotations
4
 
5
  from fastapi.testclient import TestClient
 
32
  assert payload["topic"] == "news"
33
  assert payload["topic_confidence"] == 0.8
34
  finally:
35
+ app.dependency_overrides.clear()
tests/test_data/test_download_records.py CHANGED
@@ -1,4 +1,5 @@
1
  """Unit tests for dataset record helpers in scripts.download_data."""
 
2
  from __future__ import annotations
3
 
4
  import importlib.util
@@ -26,11 +27,13 @@ class DummyDataset:
26
 
27
  class DownloadDataRecordTests(unittest.TestCase):
28
  def test_emotion_records_handles_out_of_range_labels(self) -> None:
29
- dataset_split = DummyDataset([
30
- {"text": "sample", "label": 1},
31
- {"text": "multi", "label": [0, 5]},
32
- {"text": "string", "label": "2"},
33
- ])
 
 
34
  label_names = ["sadness", "joy", "love"]
35
  records = list(
36
  download_data._emotion_records(
@@ -45,12 +48,14 @@ class DownloadDataRecordTests(unittest.TestCase):
45
  self.assertEqual(records[2]["emotions"], ["2"])
46
 
47
  def test_topic_records_handles_varied_label_inputs(self) -> None:
48
- dataset_split = DummyDataset([
49
- {"text": "news", "label": 3},
50
- {"text": "list", "label": [1]},
51
- {"text": "unknown", "label": "5"},
52
- {"text": "missing", "label": []},
53
- ])
 
 
54
  label_names = ["World", "Sports", "Business", "Sci/Tech"]
55
  records = list(
56
  download_data._topic_records(
 
1
  """Unit tests for dataset record helpers in scripts.download_data."""
2
+
3
  from __future__ import annotations
4
 
5
  import importlib.util
 
27
 
28
  class DownloadDataRecordTests(unittest.TestCase):
29
  def test_emotion_records_handles_out_of_range_labels(self) -> None:
30
+ dataset_split = DummyDataset(
31
+ [
32
+ {"text": "sample", "label": 1},
33
+ {"text": "multi", "label": [0, 5]},
34
+ {"text": "string", "label": "2"},
35
+ ]
36
+ )
37
  label_names = ["sadness", "joy", "love"]
38
  records = list(
39
  download_data._emotion_records(
 
48
  self.assertEqual(records[2]["emotions"], ["2"])
49
 
50
  def test_topic_records_handles_varied_label_inputs(self) -> None:
51
+ dataset_split = DummyDataset(
52
+ [
53
+ {"text": "news", "label": 3},
54
+ {"text": "list", "label": [1]},
55
+ {"text": "unknown", "label": "5"},
56
+ {"text": "missing", "label": []},
57
+ ]
58
+ )
59
  label_names = ["World", "Sports", "Business", "Sci/Tech"]
60
  records = list(
61
  download_data._topic_records(
tests/test_inference/test_pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  """Integration tests for the inference pipeline."""
 
2
  from __future__ import annotations
3
 
4
  from pathlib import Path
 
1
  """Integration tests for the inference pipeline."""
2
+
3
  from __future__ import annotations
4
 
5
  from pathlib import Path
tests/test_models/test_positional_encoding.py CHANGED
@@ -4,7 +4,6 @@
4
  Tests for positional encoding.
5
  """
6
 
7
-
8
  import matplotlib
9
  import torch
10
 
 
4
  Tests for positional encoding.
5
  """
6
 
 
7
  import matplotlib
8
  import torch
9