File size: 8,275 Bytes
a47e5cf
 
 
 
 
 
 
 
 
29f2de2
185b05e
 
69b8f98
1601799
a47e5cf
185b05e
 
 
 
 
 
69b8f98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a47e5cf
 
185b05e
c0044cc
185b05e
a47e5cf
185b05e
 
69b8f98
baf3026
076bc18
185b05e
 
 
 
c0044cc
185b05e
 
 
 
 
c0044cc
185b05e
 
 
 
 
a47e5cf
 
 
185b05e
a47e5cf
185b05e
 
 
 
 
 
 
 
 
 
 
 
 
 
a47e5cf
 
 
 
 
 
 
 
185b05e
 
 
a47e5cf
 
 
 
185b05e
 
a47e5cf
185b05e
 
a47e5cf
1601799
 
 
185b05e
 
7bfcb3e
a47e5cf
 
7bfcb3e
185b05e
a47e5cf
 
f6d689c
 
a47e5cf
f6d689c
a47e5cf
 
 
 
 
b650f7d
7bfcb3e
b5ddd7b
 
bdb7386
b5ddd7b
 
a47e5cf
 
b650f7d
69b8f98
baf3026
bdb7386
b5ddd7b
f6d689c
69b8f98
 
076bc18
 
69b8f98
f6d689c
a47e5cf
185b05e
 
 
 
 
 
 
a47e5cf
185b05e
 
a47e5cf
 
185b05e
1601799
 
 
 
a47e5cf
185b05e
 
a47e5cf
185b05e
 
a47e5cf
185b05e
 
 
a7d82d1
a47e5cf
185b05e
a47e5cf
 
 
 
 
 
 
 
 
185b05e
 
a47e5cf
185b05e
 
a47e5cf
 
185b05e
1601799
 
 
 
185b05e
 
a47e5cf
185b05e
 
a47e5cf
185b05e
a47e5cf
f6d689c
a47e5cf
 
 
 
f6d689c
185b05e
 
a47e5cf
 
 
 
 
 
 
185b05e
 
 
 
 
 
 
a47e5cf
1601799
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""
Inference pipeline for LexiMind.

Unified interface for summarization, emotion detection, and topic classification
with batched processing and device management.

Author: Oliver Perrin
Date: December 2025
"""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, cast

import torch
import torch.nn.functional as F

from ..data.tokenization import Tokenizer

# --------------- Text Formatting ---------------


def _format_summary(text: str) -> str:
    """Clean and format generated summary text.

    - Capitalize first letter
    - Fix period spacing (". " not " .")
    - Remove extra whitespace
    - Ensure proper sentence endings
    """
    if not text:
        return text

    # Strip and normalize whitespace
    text = " ".join(text.split())

    # Remove leading punctuation/special chars
    text = re.sub(r"^[^A-Za-z0-9]+", "", text)

    # Fix spacing around punctuation
    text = re.sub(r"\s+([.!?,;:])", r"\1", text)  # Remove space before punctuation
    text = re.sub(
        r"([.!?])([A-Za-z])", r"\1 \2", text
    )  # Add space after sentence-ending punctuation

    # Capitalize first letter
    if text:
        text = text[0].upper() + text[1:]

    # Capitalize after sentence-ending punctuation
    text = re.sub(r"([.!?])\s+([a-z])", lambda m: m.group(1) + " " + m.group(2).upper(), text)

    # Ensure ends with punctuation
    if text and text[-1] not in ".!?":
        text += "."

    return text


# --------------- Configuration ---------------


@dataclass
class InferenceConfig:
    """Tunable settings for the inference pipeline."""

    # Generation cap for summaries, in tokens.
    summary_max_length: int = 128
    # >1.0 discourages tokens that were already generated.
    summary_repetition_penalty: float = 1.2
    # >1.0 increasingly favors emitting EOS as output grows (shorter summaries).
    summary_length_penalty: float = 1.5
    # When True, decoded summaries are passed through text cleanup/formatting.
    summary_formatting: bool = True
    # Minimum sigmoid score for an emotion label to be reported.
    emotion_threshold: float = 0.5
    # Torch device string; None means resolve from the model at runtime.
    device: str | None = None


@dataclass
class EmotionPrediction:
    """Thresholded multi-label emotion result for a single text."""

    # Emotion names whose score cleared the threshold.
    labels: List[str]
    # Sigmoid scores, index-aligned with ``labels``.
    scores: List[float]


@dataclass
class TopicPrediction:
    """Single-label topic result (argmax class) for one text."""

    # Winning topic name.
    label: str
    # Softmax probability of the winning topic.
    confidence: float


# --------------- Pipeline ---------------


class InferencePipeline:
    """Multi-task inference with batched processing.

    Unified front-end over a shared multi-task model: seq2seq summarization
    (encoder + greedy decoder), multi-label emotion detection (sigmoid +
    threshold), and single-label topic classification (softmax + argmax).
    The model is moved to the resolved device and switched to eval mode
    once, at construction time.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Tokenizer,
        *,
        emotion_labels: Sequence[str] | None = None,
        topic_labels: Sequence[str] | None = None,
        config: InferenceConfig | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        """Build a pipeline around ``model`` and ``tokenizer``.

        Args:
            model: Multi-task model. Needs ``encoder``/``decoder`` attributes
                for :meth:`summarize` and a ``forward(task, inputs)`` calling
                convention for the classification tasks.
            tokenizer: Provides ``batch_encode`` and ``decode_batch``.
            emotion_labels: Label names, index-aligned with the emotion head.
            topic_labels: Label names, index-aligned with the topic head.
            config: Pipeline settings; a default ``InferenceConfig`` is used
                when omitted.
            device: Overrides ``config.device``. When both are None, the
                device of the model's first parameter is used (CPU for a
                parameterless model).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or InferenceConfig()

        # Resolve device: explicit argument > config > model parameters > CPU.
        chosen = device if device is not None else self.config.device
        if chosen is None:
            param = next(model.parameters(), None)
            # Fix: `if param` invokes Tensor.__bool__, which raises
            # RuntimeError for any multi-element parameter tensor; the
            # intent is a missing-parameter check, so compare with None.
            chosen = param.device if param is not None else "cpu"
        self.device = torch.device(chosen)

        self.model.to(self.device)
        self.model.eval()

        self.emotion_labels = list(emotion_labels) if emotion_labels else None
        self.topic_labels = list(topic_labels) if topic_labels else None

    # --------------- Summarization ---------------

    def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
        """Generate a summary for each input text.

        Args:
            texts: Source documents; an empty sequence yields ``[]``.
            max_length: Per-call cap on generated tokens; defaults to
                ``config.summary_max_length``.

        Returns:
            One summary string per input, cleaned up when
            ``config.summary_formatting`` is enabled.

        Raises:
            RuntimeError: If the model does not expose encoder/decoder.
        """
        if not texts:
            return []

        encoded = self.tokenizer.batch_encode(list(texts))
        src_ids = encoded["input_ids"].to(self.device)
        src_mask = encoded["attention_mask"].to(self.device)
        # Fix: explicit None check so a caller-supplied 0 is honored instead
        # of being silently replaced by the config default (`or` is falsy-based).
        max_len = max_length if max_length is not None else self.config.summary_max_length

        model = cast(Any, self.model)
        if not hasattr(model, "encoder") or not hasattr(model, "decoder"):
            raise RuntimeError("Model must have encoder and decoder for summarization")

        with torch.inference_mode():
            # Expand the (batch, seq) padding mask to a (batch, seq, seq)
            # self-attention mask. (src_mask cannot be None here — it always
            # comes from batch_encode — so the old None guard was dead code.)
            enc_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2)
            memory = model.encoder(src_ids, mask=enc_mask)

            # Ban special tokens so they never appear inside a summary.
            ban_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
            # NOTE(review): reaches into the private `_tokenizer` attribute;
            # assumes the wrapped tokenizer may expose `unk_token_id` — confirm
            # against the Tokenizer implementation.
            unk = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
            if isinstance(unk, int):
                ban_ids.append(unk)

            generated = model.decoder.greedy_decode(
                memory=memory,
                max_len=max_len,
                start_token_id=self.tokenizer.bos_token_id,
                end_token_id=self.tokenizer.eos_token_id,
                device=self.device,
                min_len=10,
                ban_token_ids=[i for i in ban_ids if i is not None],
                no_repeat_ngram_size=3,
                repetition_penalty=self.config.summary_repetition_penalty,
                length_penalty=self.config.summary_length_penalty,
                memory_mask=src_mask,
            )

        # Decode token ids back to text, then optionally clean up formatting.
        raw_summaries = self.tokenizer.decode_batch(generated.tolist())
        if not self.config.summary_formatting:
            return raw_summaries
        return [_format_summary(s) for s in raw_summaries]

    # --------------- Emotion ---------------

    def predict_emotions(
        self,
        texts: Sequence[str],
        *,
        threshold: float | None = None,
    ) -> List[EmotionPrediction]:
        """Predict emotions (multi-label) for each input text.

        Args:
            texts: Input documents; an empty sequence yields ``[]``.
            threshold: Minimum sigmoid score for a label to be kept;
                defaults to ``config.emotion_threshold``.

        Returns:
            One ``EmotionPrediction`` per input containing the labels and
            scores that cleared the threshold.

        Raises:
            RuntimeError: If the pipeline was built without emotion labels.
        """
        if not texts:
            return []
        if not self.emotion_labels:
            raise RuntimeError("emotion_labels required for emotion prediction")

        encoded = self.tokenizer.batch_encode(list(texts))
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # Fix: explicit None check — `threshold or ...` would discard an
        # intentional threshold of 0.0.
        thresh = threshold if threshold is not None else self.config.emotion_threshold

        with torch.inference_mode():
            logits = self.model.forward("emotion", inputs)
            # Independent sigmoids: each label is its own yes/no decision.
            probs = torch.sigmoid(logits)

        results = []
        for row in probs.cpu():
            pairs = [
                (label, score)
                for label, score in zip(self.emotion_labels, row.tolist(), strict=False)
                if score >= thresh
            ]
            results.append(
                EmotionPrediction(
                    labels=[label for label, _ in pairs],
                    scores=[score for _, score in pairs],
                )
            )
        return results

    # --------------- Topic ---------------

    def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
        """Predict the single most likely topic for each input text.

        Args:
            texts: Input documents; an empty sequence yields ``[]``.

        Returns:
            One ``TopicPrediction`` (argmax label + softmax confidence)
            per input.

        Raises:
            RuntimeError: If the pipeline was built without topic labels.
        """
        if not texts:
            return []
        if not self.topic_labels:
            raise RuntimeError("topic_labels required for topic prediction")

        encoded = self.tokenizer.batch_encode(list(texts))
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

        with torch.inference_mode():
            logits = self.model.forward("topic", inputs)
            probs = F.softmax(logits, dim=-1)

        results = []
        for row in probs.cpu():
            idx = int(row.argmax().item())
            results.append(
                TopicPrediction(
                    label=self.topic_labels[idx],
                    confidence=row[idx].item(),
                )
            )
        return results

    # --------------- Batch Prediction ---------------

    def batch_predict(self, texts: Sequence[str]) -> Dict[str, Any]:
        """Run all three tasks on the same input texts.

        Returns:
            Dict with keys ``"summaries"``, ``"emotion"``, ``"topic"``.

        Raises:
            RuntimeError: If either label set is missing.
        """
        if not self.emotion_labels or not self.topic_labels:
            raise RuntimeError("Both emotion_labels and topic_labels required")

        text_list = list(texts)
        return {
            "summaries": self.summarize(text_list),
            "emotion": self.predict_emotions(text_list),
            "topic": self.predict_topics(text_list),
        }

    # --------------- Helpers ---------------
    # (helper methods removed - encoding now happens inline)