OliverPerrin committed on
Commit
69b8f98
·
1 Parent(s): 9becd3c

Improve summarization output quality

Browse files

- Add _format_summary() for proper capitalization and punctuation
- Add repetition_penalty (1.2) to reduce repetitive outputs
- Fix period spacing and sentence capitalization
- Remove leading special characters from generated text

outputs/evaluation_report.json CHANGED
@@ -1,44 +1,80 @@
1
  {
2
  "split": "val",
3
  "summarization": {
4
- "rouge_like": 0.35947467920968945,
5
- "bleu": 0.09027012433010549
6
  },
7
  "emotion": {
8
- "f1_macro": 0.9455000162124634
9
  },
10
  "topic": {
11
- "accuracy": 0.94175,
12
  "classification_report": {
13
- "Business": {
14
- "precision": 0.9319045973038369,
15
- "recall": 0.8986666666666666,
16
- "f1-score": 0.9149838791786866,
17
- "support": 3000
18
- },
19
- "Sci/Tech": {
20
- "precision": 0.9055627425614489,
21
- "recall": 0.9333333333333333,
22
- "f1-score": 0.9192383453709784,
23
- "support": 3000
24
  },
25
- "Sports": {
26
- "precision": 0.9856475300400535,
27
- "recall": 0.9843333333333333,
28
- "f1-score": 0.9849899933288859,
29
- "support": 3000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  },
31
- "World": {
32
- "precision": 0.9446836700894335,
33
- "recall": 0.9506666666666667,
34
- "f1-score": 0.9476657252035222,
35
- "support": 3000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  },
37
  "macro avg": {
38
- "precision": 0.9419496349986932,
39
- "recall": 0.94175,
40
- "f1-score": 0.9417194857705183,
41
- "support": 12000
42
  }
43
  }
44
  }
 
1
  {
2
  "split": "val",
3
  "summarization": {
4
+ "rouge_like": 0.13567121660564777,
5
+ "bleu": 0.014673668103097205
6
  },
7
  "emotion": {
8
+ "f1_macro": 0.1939181685447693
9
  },
10
  "topic": {
11
+ "accuracy": 0.741687849517031,
12
  "classification_report": {
13
+ "Business & Finance": {
14
+ "precision": 0.6439114391143912,
15
+ "recall": 0.527190332326284,
16
+ "f1-score": 0.579734219269103,
17
+ "support": 1986
 
 
 
 
 
 
18
  },
19
+ "Computers & Internet": {
20
+ "precision": 0.8251038301799724,
21
+ "recall": 0.9044006069802731,
22
+ "f1-score": 0.862934362934363,
23
+ "support": 1977
24
+ },
25
+ "Education & Reference": {
26
+ "precision": 0.6439444076770351,
27
+ "recall": 0.49642857142857144,
28
+ "f1-score": 0.560645347162201,
29
+ "support": 1960
30
+ },
31
+ "Entertainment & Music": {
32
+ "precision": 0.7064310260186549,
33
+ "recall": 0.7360613810741689,
34
+ "f1-score": 0.7209418837675351,
35
+ "support": 1955
36
+ },
37
+ "Family & Relationships": {
38
+ "precision": 0.7182971014492754,
39
+ "recall": 0.8071246819338422,
40
+ "f1-score": 0.7601246105919003,
41
+ "support": 1965
42
+ },
43
+ "Health": {
44
+ "precision": 0.7610579115367077,
45
+ "recall": 0.8489318413021363,
46
+ "f1-score": 0.8025967780716519,
47
+ "support": 1966
48
  },
49
+ "Politics & Government": {
50
+ "precision": 0.7711132437619962,
51
+ "recall": 0.8173957273652085,
52
+ "f1-score": 0.7935802469135802,
53
+ "support": 1966
54
+ },
55
+ "Science & Mathematics": {
56
+ "precision": 0.7456647398843931,
57
+ "recall": 0.7885888945491595,
58
+ "f1-score": 0.7665263679128497,
59
+ "support": 1963
60
+ },
61
+ "Society & Culture": {
62
+ "precision": 0.6496559633027523,
63
+ "recall": 0.5783563042368556,
64
+ "f1-score": 0.6119362678908993,
65
+ "support": 1959
66
+ },
67
+ "Sports": {
68
+ "precision": 0.8888339920948617,
69
+ "recall": 0.9118094272681196,
70
+ "f1-score": 0.9001751313485113,
71
+ "support": 1973
72
  },
73
  "macro avg": {
74
+ "precision": 0.735401365502004,
75
+ "recall": 0.7416287768464619,
76
+ "f1-score": 0.7359195215862595,
77
+ "support": 19670
78
  }
79
  }
80
  }
src/inference/pipeline.py CHANGED
@@ -10,6 +10,7 @@ Date: December 2025
10
 
11
  from __future__ import annotations
12
 
 
13
  from dataclasses import dataclass, fields, replace
14
  from typing import Any, Dict, List, Sequence, cast
15
 
@@ -19,6 +20,46 @@ import torch.nn.functional as F
19
  from ..data.preprocessing import Batch, TextPreprocessor
20
  from ..data.tokenization import Tokenizer
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  # --------------- Configuration ---------------
23
 
24
 
@@ -27,6 +68,7 @@ class InferenceConfig:
27
  """Pipeline settings."""
28
 
29
  summary_max_length: int = 128
 
30
  emotion_threshold: float = 0.5
31
  device: str | None = None
32
 
@@ -116,10 +158,13 @@ class InferencePipeline:
116
  min_len=10,
117
  ban_token_ids=[i for i in ban_ids if i is not None],
118
  no_repeat_ngram_size=3,
 
119
  memory_mask=src_mask,
120
  )
121
 
122
- return self.tokenizer.decode_batch(generated.tolist())
 
 
123
 
124
  # --------------- Emotion ---------------
125
 
 
10
 
11
  from __future__ import annotations
12
 
13
+ import re
14
  from dataclasses import dataclass, fields, replace
15
  from typing import Any, Dict, List, Sequence, cast
16
 
 
20
  from ..data.preprocessing import Batch, TextPreprocessor
21
  from ..data.tokenization import Tokenizer
22
 
23
+ # --------------- Text Formatting ---------------
24
+
25
+
26
+ def _format_summary(text: str) -> str:
27
+ """Clean and format generated summary text.
28
+
29
+ - Capitalize first letter
30
+ - Fix period spacing (". " not " .")
31
+ - Remove extra whitespace
32
+ - Ensure proper sentence endings
33
+ """
34
+ if not text:
35
+ return text
36
+
37
+ # Strip and normalize whitespace
38
+ text = " ".join(text.split())
39
+
40
+ # Remove leading punctuation/special chars
41
+ text = re.sub(r"^[^A-Za-z0-9]+", "", text)
42
+
43
+ # Fix spacing around punctuation
44
+ text = re.sub(r"\s+([.!?,;:])", r"\1", text) # Remove space before punctuation
45
+ text = re.sub(
46
+ r"([.!?])([A-Za-z])", r"\1 \2", text
47
+ ) # Add space after sentence-ending punctuation
48
+
49
+ # Capitalize first letter
50
+ if text:
51
+ text = text[0].upper() + text[1:]
52
+
53
+ # Capitalize after sentence-ending punctuation
54
+ text = re.sub(r"([.!?])\s+([a-z])", lambda m: m.group(1) + " " + m.group(2).upper(), text)
55
+
56
+ # Ensure ends with punctuation
57
+ if text and text[-1] not in ".!?":
58
+ text += "."
59
+
60
+ return text
61
+
62
+
63
  # --------------- Configuration ---------------
64
 
65
 
 
68
  """Pipeline settings."""
69
 
70
  summary_max_length: int = 128
71
+ summary_repetition_penalty: float = 1.2 # Penalize repeated tokens
72
  emotion_threshold: float = 0.5
73
  device: str | None = None
74
 
 
158
  min_len=10,
159
  ban_token_ids=[i for i in ban_ids if i is not None],
160
  no_repeat_ngram_size=3,
161
+ repetition_penalty=self.config.summary_repetition_penalty,
162
  memory_mask=src_mask,
163
  )
164
 
165
+ # Decode and format summaries
166
+ raw_summaries = self.tokenizer.decode_batch(generated.tolist())
167
+ return [_format_summary(s) for s in raw_summaries]
168
 
169
  # --------------- Emotion ---------------
170