Spaces:
Running
Running
File size: 8,275 Bytes
a47e5cf 29f2de2 185b05e 69b8f98 1601799 a47e5cf 185b05e 69b8f98 a47e5cf 185b05e c0044cc 185b05e a47e5cf 185b05e 69b8f98 baf3026 076bc18 185b05e c0044cc 185b05e c0044cc 185b05e a47e5cf 185b05e a47e5cf 185b05e a47e5cf 185b05e a47e5cf 185b05e a47e5cf 185b05e a47e5cf 1601799 185b05e 7bfcb3e a47e5cf 7bfcb3e 185b05e a47e5cf f6d689c a47e5cf f6d689c a47e5cf b650f7d 7bfcb3e b5ddd7b bdb7386 b5ddd7b a47e5cf b650f7d 69b8f98 baf3026 bdb7386 b5ddd7b f6d689c 69b8f98 076bc18 69b8f98 f6d689c a47e5cf 185b05e a47e5cf 185b05e a47e5cf 185b05e 1601799 a47e5cf 185b05e a47e5cf 185b05e a47e5cf 185b05e a7d82d1 a47e5cf 185b05e a47e5cf 185b05e a47e5cf 185b05e a47e5cf 185b05e 1601799 185b05e a47e5cf 185b05e a47e5cf 185b05e a47e5cf f6d689c a47e5cf f6d689c 185b05e a47e5cf 185b05e a47e5cf 1601799 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
"""
Inference pipeline for LexiMind.
Unified interface for summarization, emotion detection, and topic classification
with batched processing and device management.
Author: Oliver Perrin
Date: December 2025
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, cast
import torch
import torch.nn.functional as F
from ..data.tokenization import Tokenizer
# --------------- Text Formatting ---------------
def _format_summary(text: str) -> str:
"""Clean and format generated summary text.
- Capitalize first letter
- Fix period spacing (". " not " .")
- Remove extra whitespace
- Ensure proper sentence endings
"""
if not text:
return text
# Strip and normalize whitespace
text = " ".join(text.split())
# Remove leading punctuation/special chars
text = re.sub(r"^[^A-Za-z0-9]+", "", text)
# Fix spacing around punctuation
text = re.sub(r"\s+([.!?,;:])", r"\1", text) # Remove space before punctuation
text = re.sub(
r"([.!?])([A-Za-z])", r"\1 \2", text
) # Add space after sentence-ending punctuation
# Capitalize first letter
if text:
text = text[0].upper() + text[1:]
# Capitalize after sentence-ending punctuation
text = re.sub(r"([.!?])\s+([a-z])", lambda m: m.group(1) + " " + m.group(2).upper(), text)
# Ensure ends with punctuation
if text and text[-1] not in ".!?":
text += "."
return text
# --------------- Configuration ---------------
@dataclass
class InferenceConfig:
    """Pipeline settings."""

    # Maximum number of tokens generated per summary.
    summary_max_length: int = 128
    summary_repetition_penalty: float = 1.2  # Penalize repeated tokens
    summary_length_penalty: float = 1.5  # Encourage EOS token as length increases (>1 = shorter)
    summary_formatting: bool = True  # Apply text cleanup/formatting to generated summaries
    # Minimum sigmoid score for an emotion label to be reported.
    emotion_threshold: float = 0.5
    # Device spec (e.g. "cuda", "cpu"); None = resolve from model parameters.
    device: str | None = None
@dataclass
class EmotionPrediction:
    """Multi-label emotion result: parallel lists of labels and their scores."""

    labels: List[str]  # Emotion names whose score met the threshold
    scores: List[float]  # Sigmoid scores, aligned index-for-index with `labels`
@dataclass
class TopicPrediction:
    """Single-label topic result: the argmax label and its softmax probability."""

    label: str  # Highest-probability topic name
    confidence: float  # Softmax probability of `label`
# --------------- Pipeline ---------------
class InferencePipeline:
    """Multi-task inference with batched processing.

    Wraps a trained multi-head model and its tokenizer to provide
    summarization, multi-label emotion detection, and single-label topic
    classification over batches of raw text.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer: Tokenizer,
        *,
        emotion_labels: Sequence[str] | None = None,
        topic_labels: Sequence[str] | None = None,
        config: InferenceConfig | None = None,
        device: torch.device | str | None = None,
    ) -> None:
        """Build a pipeline around ``model`` and ``tokenizer``.

        Args:
            model: Trained model; moved to the resolved device and put in eval mode.
            tokenizer: Tokenizer used for batch encoding/decoding.
            emotion_labels: Label names for the emotion head (required by
                ``predict_emotions`` / ``batch_predict``).
            topic_labels: Label names for the topic head (required by
                ``predict_topics`` / ``batch_predict``).
            config: Pipeline settings; defaults to ``InferenceConfig()``.
            device: Explicit device override. Resolution order: this argument,
                then ``config.device``, then the model's own parameters, then CPU.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or InferenceConfig()
        # Resolve device priority: explicit arg > config > model parameters > CPU.
        chosen = device or self.config.device
        if chosen is None:
            param = next(model.parameters(), None)
            # BUGFIX: must compare against None. `if param` invokes
            # Tensor.__bool__, which raises "Boolean value of Tensor with more
            # than one element is ambiguous" for any real (multi-element) parameter.
            chosen = param.device if param is not None else "cpu"
        self.device = torch.device(chosen)
        self.model.to(self.device)
        self.model.eval()
        self.emotion_labels = list(emotion_labels) if emotion_labels else None
        self.topic_labels = list(topic_labels) if topic_labels else None

    # --------------- Summarization ---------------
    def summarize(self, texts: Sequence[str], *, max_length: int | None = None) -> List[str]:
        """Generate summaries for input texts.

        Args:
            texts: Source documents; an empty sequence yields ``[]``.
            max_length: Optional cap on generated tokens; defaults to
                ``config.summary_max_length``.

        Returns:
            One (optionally formatted) summary string per input text.

        Raises:
            RuntimeError: If the model lacks ``encoder``/``decoder`` attributes.
        """
        if not texts:
            return []
        encoded = self.tokenizer.batch_encode(list(texts))
        src_ids = encoded["input_ids"].to(self.device)
        src_mask = encoded["attention_mask"].to(self.device)
        max_len = max_length or self.config.summary_max_length
        model = cast(Any, self.model)
        if not hasattr(model, "encoder") or not hasattr(model, "decoder"):
            raise RuntimeError("Model must have encoder and decoder for summarization")
        with torch.inference_mode():
            # Expand the (batch, seq) padding mask into a pairwise
            # (batch, seq, seq) self-attention mask for the encoder.
            enc_mask = (
                src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
            )
            memory = model.encoder(src_ids, mask=enc_mask)
            # Ban special tokens so they never appear inside the summary.
            ban_ids = [self.tokenizer.bos_token_id, self.tokenizer.pad_token_id]
            # NOTE(review): reaches into the tokenizer's private backend for
            # unk_token_id — confirm the wrapper exposes no public accessor.
            unk = getattr(self.tokenizer._tokenizer, "unk_token_id", None)
            if isinstance(unk, int):
                ban_ids.append(unk)
            generated = model.decoder.greedy_decode(
                memory=memory,
                max_len=max_len,
                start_token_id=self.tokenizer.bos_token_id,
                end_token_id=self.tokenizer.eos_token_id,
                device=self.device,
                min_len=10,
                ban_token_ids=[i for i in ban_ids if i is not None],
                no_repeat_ngram_size=3,
                repetition_penalty=self.config.summary_repetition_penalty,
                length_penalty=self.config.summary_length_penalty,
                memory_mask=src_mask,
            )
        # Decode token ids back to text, then optionally clean up formatting.
        raw_summaries = self.tokenizer.decode_batch(generated.tolist())
        if not self.config.summary_formatting:
            return raw_summaries
        return [_format_summary(s) for s in raw_summaries]

    # --------------- Emotion ---------------
    def predict_emotions(
        self,
        texts: Sequence[str],
        *,
        threshold: float | None = None,
    ) -> List[EmotionPrediction]:
        """Predict emotions for input texts (multi-label, sigmoid-scored).

        Args:
            texts: Input documents; an empty sequence yields ``[]``.
            threshold: Minimum score for a label to be kept; defaults to
                ``config.emotion_threshold``. ``0.0`` keeps every label.

        Returns:
            One ``EmotionPrediction`` per input text.

        Raises:
            RuntimeError: If the pipeline was built without ``emotion_labels``.
        """
        if not texts:
            return []
        if not self.emotion_labels:
            raise RuntimeError("emotion_labels required for emotion prediction")
        encoded = self.tokenizer.batch_encode(list(texts))
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # BUGFIX: `threshold or default` silently discarded an explicit 0.0;
        # compare against None so "report every label" is expressible.
        thresh = threshold if threshold is not None else self.config.emotion_threshold
        with torch.inference_mode():
            logits = self.model.forward("emotion", inputs)
            probs = torch.sigmoid(logits)
        results = []
        for row in probs.cpu():
            # Keep only labels whose sigmoid score clears the threshold.
            pairs = [
                (label, score)
                for label, score in zip(self.emotion_labels, row.tolist(), strict=False)
                if score >= thresh
            ]
            results.append(
                EmotionPrediction(
                    labels=[label for label, _ in pairs],
                    scores=[score for _, score in pairs],
                )
            )
        return results

    # --------------- Topic ---------------
    def predict_topics(self, texts: Sequence[str]) -> List[TopicPrediction]:
        """Predict the single most likely topic for each input text.

        Args:
            texts: Input documents; an empty sequence yields ``[]``.

        Returns:
            One ``TopicPrediction`` (argmax label + softmax confidence) per text.

        Raises:
            RuntimeError: If the pipeline was built without ``topic_labels``.
        """
        if not texts:
            return []
        if not self.topic_labels:
            raise RuntimeError("topic_labels required for topic prediction")
        encoded = self.tokenizer.batch_encode(list(texts))
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        with torch.inference_mode():
            logits = self.model.forward("topic", inputs)
            probs = F.softmax(logits, dim=-1)
        results = []
        for row in probs.cpu():
            idx = int(row.argmax().item())
            results.append(
                TopicPrediction(
                    label=self.topic_labels[idx],
                    confidence=row[idx].item(),
                )
            )
        return results

    # --------------- Batch Prediction ---------------
    def batch_predict(self, texts: Sequence[str]) -> Dict[str, Any]:
        """Run all three tasks on input texts.

        Returns:
            Dict with keys ``"summaries"``, ``"emotion"``, ``"topic"``.

        Raises:
            RuntimeError: If either label set is missing.
        """
        if not self.emotion_labels or not self.topic_labels:
            raise RuntimeError("Both emotion_labels and topic_labels required")
        text_list = list(texts)
        return {
            "summaries": self.summarize(text_list),
            "emotion": self.predict_emotions(text_list),
            "topic": self.predict_topics(text_list),
        }


# --------------- Helpers ---------------
# (helper methods removed - encoding now happens inline)
|