Spaces:
Running
Running
Commit
·
c8c20f1
1
Parent(s):
0f7aa90
Fixing Model Pipeline Problem
Browse files- scripts/demo_gradio.py +181 -12
- src/inference/pipeline.py +63 -9
scripts/demo_gradio.py
CHANGED
|
@@ -3,11 +3,13 @@ Gradio Demo interface for LexiMind NLP pipeline.
|
|
| 3 |
Showcases summarization, emotion detection, and topic prediction.
|
| 4 |
"""
|
| 5 |
import json
|
|
|
|
| 6 |
import sys
|
| 7 |
from pathlib import Path
|
| 8 |
from tempfile import NamedTemporaryFile
|
| 9 |
from typing import Iterable, Sequence
|
| 10 |
from textwrap import dedent
|
|
|
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
from gradio.themes import Soft
|
|
@@ -31,6 +33,69 @@ logger = get_logger(__name__)
|
|
| 31 |
_pipeline: InferencePipeline | None = None # Global pipeline instance
|
| 32 |
_label_metadata = None # Cached label metadata
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
def get_pipeline() -> InferencePipeline:
|
| 35 |
"""Lazy Loading and Caching the inference pipeline"""
|
| 36 |
global _pipeline, _label_metadata
|
|
@@ -88,30 +153,60 @@ def predict(text: str, compression: int):
|
|
| 88 |
logger.info("Generating summary with max length of %s", max_len)
|
| 89 |
|
| 90 |
summary = pipeline.summarize([text], max_length=max_len)[0]
|
| 91 |
-
|
|
|
|
| 92 |
topic = pipeline.predict_topics([text])[0]
|
| 93 |
|
| 94 |
clean_summary = summary.strip()
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
| 96 |
summary_source = clean_summary
|
| 97 |
-
summary_notice = ""
|
| 98 |
else:
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
"
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
summary_html = format_summary(text, summary_source, notice=summary_notice)
|
| 106 |
-
emotion_plot = create_emotion_plot(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
topic_output = format_topic(topic)
|
| 108 |
-
if clean_summary:
|
| 109 |
attention_fig = create_attention_heatmap(text, clean_summary, pipeline)
|
| 110 |
else:
|
| 111 |
attention_fig = render_unavailable_message(
|
| 112 |
-
"Attention heatmap unavailable because the
|
| 113 |
)
|
| 114 |
-
download_path = prepare_download(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
download_update = gr.update(value=download_path, visible=True)
|
| 116 |
|
| 117 |
return summary_html, emotion_plot, topic_output, attention_fig, download_update
|
|
@@ -212,6 +307,70 @@ def _clean_tokens(tokens: Iterable[str]) -> list[str]:
|
|
| 212 |
cleaned.append(item.strip() if item.strip() else token)
|
| 213 |
return cleaned
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
|
| 216 |
"""Build a lightweight extractive summary when the model generates nothing."""
|
| 217 |
if not text.strip():
|
|
@@ -335,6 +494,10 @@ def prepare_download(
|
|
| 335 |
summary: str,
|
| 336 |
emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
|
| 337 |
topic: TopicPrediction | dict[str, float | str],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
) -> str:
|
| 339 |
"""Persist JSON payload to a temporary file and return its path for download."""
|
| 340 |
if isinstance(emotions, EmotionPrediction):
|
|
@@ -362,9 +525,15 @@ def prepare_download(
|
|
| 362 |
payload = {
|
| 363 |
"original_text": text,
|
| 364 |
"summary": summary,
|
|
|
|
|
|
|
| 365 |
"emotions": emotion_payload,
|
| 366 |
"topic": topic_payload,
|
| 367 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
|
| 369 |
json.dump(payload, handle, ensure_ascii=False, indent=2)
|
| 370 |
temp_path = handle.name
|
|
|
|
| 3 |
Showcases summarization, emotion detection, and topic prediction.
|
| 4 |
"""
|
| 5 |
import json
|
| 6 |
+
import re
|
| 7 |
import sys
|
| 8 |
from pathlib import Path
|
| 9 |
from tempfile import NamedTemporaryFile
|
| 10 |
from typing import Iterable, Sequence
|
| 11 |
from textwrap import dedent
|
| 12 |
+
from collections import Counter
|
| 13 |
|
| 14 |
import gradio as gr
|
| 15 |
from gradio.themes import Soft
|
|
|
|
| 33 |
_pipeline: InferencePipeline | None = None # Global pipeline instance
|
| 34 |
_label_metadata = None # Cached label metadata
|
| 35 |
|
| 36 |
+
# Function words that summary_is_plausible() discards before it compares the
# vocabulary of a generated summary with the vocabulary of the source text.
STOPWORDS = {
    "the", "is", "a", "an", "to", "of", "and", "in", "it", "that",
    "for", "on", "with", "as", "by", "be", "are", "was", "were", "this",
    "which", "at", "or", "from", "but", "has", "have", "had", "can", "will",
    "would",
}

# Minimum score an emotion label must reach to survive filter_emotions().
# Labels without an entry fall back to 0.5 inside that function.
EMOTION_THRESHOLDS = {
    "anger": 0.6,
    "fear": 0.85,
    "joy": 0.6,
    "love": 0.25,
    "sadness": 0.3,
    "surprise": 0.55,
}

# Vocabulary that filter_emotions() looks for in the input text before it
# accepts a low-scoring "love" prediction. "ador" is a stem rather than a
# whole word (presumably meant to cover adore/adored/adoring).
EMOTION_KEYWORDS = {
    "love": {
        "love", "loved", "loving", "beloved",
        "romance", "romantic", "affection", "passion",
        "sweetheart", "valentine", "dear", "cherish",
        "ador", "marriage", "wedding",
    }
}
|
| 98 |
+
|
| 99 |
def get_pipeline() -> InferencePipeline:
|
| 100 |
"""Lazy Loading and Caching the inference pipeline"""
|
| 101 |
global _pipeline, _label_metadata
|
|
|
|
| 153 |
logger.info("Generating summary with max length of %s", max_len)
|
| 154 |
|
| 155 |
summary = pipeline.summarize([text], max_length=max_len)[0]
|
| 156 |
+
raw_emotion_pairs = extract_emotion_pairs(pipeline.predict_emotions([text], threshold=0.0)[0])
|
| 157 |
+
filtered_emotions = filter_emotions(raw_emotion_pairs, text)
|
| 158 |
topic = pipeline.predict_topics([text])[0]
|
| 159 |
|
| 160 |
clean_summary = summary.strip()
|
| 161 |
+
summary_notice = ""
|
| 162 |
+
fallback_summary: str | None = None
|
| 163 |
+
|
| 164 |
+
if clean_summary and summary_is_plausible(clean_summary, text):
|
| 165 |
summary_source = clean_summary
|
|
|
|
| 166 |
else:
|
| 167 |
+
fallback_summary = generate_fallback_summary(text)
|
| 168 |
+
summary_source = fallback_summary
|
| 169 |
+
if clean_summary:
|
| 170 |
+
logger.info("Neural summary flagged as low-overlap; showing extractive fallback instead")
|
| 171 |
+
summary_notice = dedent(
|
| 172 |
+
f"""
|
| 173 |
+
<p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> The neural summary looked off-topic, so an extractive fallback is shown above.</p>
|
| 174 |
+
<details style=\"margin-top: 4px;\">
|
| 175 |
+
<summary style=\"color: #b45309; cursor: pointer;\">View the original neural summary</summary>
|
| 176 |
+
<p style=\"margin-top: 8px; background-color: #fff7ed; padding: 10px; border-radius: 4px; color: #7c2d12; white-space: pre-wrap;\">
|
| 177 |
+
{clean_summary}
|
| 178 |
+
</p>
|
| 179 |
+
</details>
|
| 180 |
+
"""
|
| 181 |
+
).strip()
|
| 182 |
+
else:
|
| 183 |
+
summary_notice = (
|
| 184 |
+
"<p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> "
|
| 185 |
+
"The model did not produce a summary, so an extractive fallback is shown instead.</p>"
|
| 186 |
+
)
|
| 187 |
|
| 188 |
summary_html = format_summary(text, summary_source, notice=summary_notice)
|
| 189 |
+
emotion_plot = create_emotion_plot(filtered_emotions)
|
| 190 |
+
if emotion_plot is None:
|
| 191 |
+
emotion_plot = render_unavailable_message(
|
| 192 |
+
"No emotion met the confidence threshold."
|
| 193 |
+
)
|
| 194 |
topic_output = format_topic(topic)
|
| 195 |
+
if clean_summary and fallback_summary is None:
|
| 196 |
attention_fig = create_attention_heatmap(text, clean_summary, pipeline)
|
| 197 |
else:
|
| 198 |
attention_fig = render_unavailable_message(
|
| 199 |
+
"Attention heatmap unavailable because the neural summary was empty or flagged as unreliable."
|
| 200 |
)
|
| 201 |
+
download_path = prepare_download(
|
| 202 |
+
text,
|
| 203 |
+
summary_source,
|
| 204 |
+
filtered_emotions,
|
| 205 |
+
topic,
|
| 206 |
+
neural_summary=clean_summary or None,
|
| 207 |
+
fallback_summary=fallback_summary,
|
| 208 |
+
raw_emotions=raw_emotion_pairs,
|
| 209 |
+
)
|
| 210 |
download_update = gr.update(value=download_path, visible=True)
|
| 211 |
|
| 212 |
return summary_html, emotion_plot, topic_output, attention_fig, download_update
|
|
|
|
| 307 |
cleaned.append(item.strip() if item.strip() else token)
|
| 308 |
return cleaned
|
| 309 |
|
| 310 |
+
def extract_emotion_pairs(
    emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]
) -> list[tuple[str, float]]:
    """Normalise an emotion prediction into ``(label, score)`` tuples.

    Args:
        emotions: Either an ``EmotionPrediction`` (its ``labels`` and ``scores``
            attributes are read) or a mapping with ``"labels"`` and ``"scores"``
            entries; a missing entry is treated as an empty list.

    Returns:
        One ``(str(label), float(score))`` tuple per position. Pairing uses
        ``zip``, so the result is as long as the shorter of the two sequences.
    """
    if isinstance(emotions, EmotionPrediction):
        names, values = emotions.labels, emotions.scores
    else:
        names = emotions.get("labels", [])
        values = emotions.get("scores", [])

    return [(str(name), float(value)) for name, value in zip(names, values)]
|
| 318 |
+
|
| 319 |
+
def filter_emotions(pairs: list[tuple[str, float]], text: str) -> EmotionPrediction:
    """Keep only the emotion predictions that clear their confidence bar.

    Every ``(label, score)`` pair is compared against ``EMOTION_THRESHOLDS``
    (0.5 for labels that have no entry). ``"love"`` gets one extra check: a
    score below 0.6 is kept only when *text* contains a word that starts with
    one of ``EMOTION_KEYWORDS["love"]``.

    Args:
        pairs: ``(label, score)`` tuples, e.g. from ``extract_emotion_pairs``.
        text: The input text the predictions were made for; matched
            case-insensitively against the love keywords.

    Returns:
        An ``EmotionPrediction`` holding the surviving labels and scores in
        their original order; both lists are empty when nothing survives.
    """
    lowered_text = text.lower()
    # Result of the keyword scan; computed lazily and at most once per call.
    love_supported: bool | None = None

    kept: list[tuple[str, float]] = []
    for label, score in pairs:
        if score < EMOTION_THRESHOLDS.get(label, 0.5):
            continue

        if label == "love" and score < 0.6:
            if love_supported is None:
                stems = EMOTION_KEYWORDS.get("love", set())
                # Anchor each keyword at the start of a word. Plain substring
                # containment treated "glove", "clover" or "endear" as love
                # vocabulary; a word-start match still lets stems such as
                # "ador" cover "adore"/"adoring".
                pattern = r"\b(?:" + "|".join(map(re.escape, sorted(stems))) + ")"
                love_supported = bool(stems) and re.search(pattern, lowered_text) is not None
            if not love_supported:
                continue

        kept.append((label, score))

    return EmotionPrediction(
        labels=[label for label, _ in kept],
        scores=[score for _, score in kept],
    )
|
| 340 |
+
|
| 341 |
+
def summary_is_plausible(
    summary: str,
    original: str,
    *,
    min_overlap: float = 0.2,
    min_unique_ratio: float = 0.3,
    max_repeat_ratio: float = 0.6,
) -> bool:
    """Heuristically decide whether a neural summary is usable.

    Both texts are lower-cased and split into ``\\w+`` tokens, and tokens found
    in ``STOPWORDS`` are ignored.

    Args:
        summary: The generated summary to vet.
        original: The source text the summary was generated from.
        min_overlap: Minimum share of summary content tokens that must also
            occur in the source text.
        min_unique_ratio: Minimum ratio of distinct content tokens to all
            content tokens in the summary.
        max_repeat_ratio: Maximum share of the summary's content tokens that
            the single most frequent token may take up.

    Returns:
        ``False`` when the summary has no content tokens, overlaps too little
        with the source, or is too repetitive; ``True`` otherwise.
    """
    content_words = [
        word for word in re.findall(r"\w+", summary.lower()) if word not in STOPWORDS
    ]
    # Covers both "no tokens at all" and "nothing but stopwords".
    if not content_words:
        return False
    total = len(content_words)

    source_vocab = set(re.findall(r"\w+", original.lower())) - STOPWORDS
    shared = sum(word in source_vocab for word in content_words)
    if shared / total < min_overlap:
        return False

    frequencies = Counter(content_words)
    distinct_share = len(frequencies) / total
    top_share = max(frequencies.values()) / total
    if distinct_share < min_unique_ratio or top_share > max_repeat_ratio:
        return False
    return True
|
| 373 |
+
|
| 374 |
def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
|
| 375 |
"""Build a lightweight extractive summary when the model generates nothing."""
|
| 376 |
if not text.strip():
|
|
|
|
| 494 |
summary: str,
|
| 495 |
emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
|
| 496 |
topic: TopicPrediction | dict[str, float | str],
|
| 497 |
+
*,
|
| 498 |
+
neural_summary: str | None = None,
|
| 499 |
+
fallback_summary: str | None = None,
|
| 500 |
+
raw_emotions: Sequence[tuple[str, float]] | None = None,
|
| 501 |
) -> str:
|
| 502 |
"""Persist JSON payload to a temporary file and return its path for download."""
|
| 503 |
if isinstance(emotions, EmotionPrediction):
|
|
|
|
| 525 |
payload = {
|
| 526 |
"original_text": text,
|
| 527 |
"summary": summary,
|
| 528 |
+
"neural_summary": neural_summary,
|
| 529 |
+
"fallback_summary": fallback_summary,
|
| 530 |
"emotions": emotion_payload,
|
| 531 |
"topic": topic_payload,
|
| 532 |
}
|
| 533 |
+
if raw_emotions is not None:
|
| 534 |
+
payload["raw_emotions"] = [
|
| 535 |
+
{"label": label, "score": float(score)} for label, score in raw_emotions
|
| 536 |
+
]
|
| 537 |
with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
|
| 538 |
json.dump(payload, handle, ensure_ascii=False, indent=2)
|
| 539 |
temp_path = handle.name
|
src/inference/pipeline.py
CHANGED
|
@@ -75,15 +75,13 @@ class InferencePipeline:
|
|
| 75 |
with torch.inference_mode():
|
| 76 |
encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
|
| 77 |
memory = self.model.encoder(src_ids, mask=encoder_mask)
|
| 78 |
-
generated = self.
|
| 79 |
-
memory=memory,
|
| 80 |
-
max_len=max_len,
|
| 81 |
-
start_token_id=self.tokenizer.bos_token_id,
|
| 82 |
-
end_token_id=self.tokenizer.eos_token_id,
|
| 83 |
-
device=self.device,
|
| 84 |
-
)
|
| 85 |
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
def predict_emotions(
|
| 89 |
self,
|
|
@@ -98,7 +96,7 @@ class InferencePipeline:
|
|
| 98 |
|
| 99 |
batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
|
| 100 |
model_inputs = self._batch_to_model_inputs(batch)
|
| 101 |
-
decision_threshold =
|
| 102 |
|
| 103 |
with torch.inference_mode():
|
| 104 |
logits = self.model.forward("emotion", model_inputs)
|
|
@@ -148,6 +146,62 @@ class InferencePipeline:
|
|
| 148 |
"topic": self.predict_topics(text_list),
|
| 149 |
}
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
def _batch_to_device(self, batch: Batch) -> Batch:
|
| 152 |
tensor_updates: dict[str, torch.Tensor] = {}
|
| 153 |
for item in fields(batch):
|
|
|
|
| 75 |
with torch.inference_mode():
|
| 76 |
encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
|
| 77 |
memory = self.model.encoder(src_ids, mask=encoder_mask)
|
| 78 |
+
generated = self._constrained_greedy_decode(memory, max_len, memory_mask=src_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
+
trimmed_sequences: List[List[int]] = []
|
| 81 |
+
for row in generated.cpu().tolist():
|
| 82 |
+
trimmed_sequences.append(self._trim_special_tokens(row))
|
| 83 |
+
|
| 84 |
+
return self.tokenizer.decode_batch(trimmed_sequences)
|
| 85 |
|
| 86 |
def predict_emotions(
|
| 87 |
self,
|
|
|
|
| 96 |
|
| 97 |
batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
|
| 98 |
model_inputs = self._batch_to_model_inputs(batch)
|
| 99 |
+
decision_threshold = self.config.emotion_threshold if threshold is None else float(threshold)
|
| 100 |
|
| 101 |
with torch.inference_mode():
|
| 102 |
logits = self.model.forward("emotion", model_inputs)
|
|
|
|
| 146 |
"topic": self.predict_topics(text_list),
|
| 147 |
}
|
| 148 |
|
| 149 |
+
def _constrained_greedy_decode(
|
| 150 |
+
self,
|
| 151 |
+
memory: torch.Tensor,
|
| 152 |
+
max_len: int,
|
| 153 |
+
*,
|
| 154 |
+
memory_mask: torch.Tensor | None = None,
|
| 155 |
+
) -> torch.Tensor:
|
| 156 |
+
"""Run greedy decoding while banning BOS/PAD tokens from the generated sequence."""
|
| 157 |
+
|
| 158 |
+
device = memory.device
|
| 159 |
+
batch_size = memory.size(0)
|
| 160 |
+
bos = self.tokenizer.bos_token_id
|
| 161 |
+
pad = getattr(self.tokenizer, "pad_token_id", None)
|
| 162 |
+
eos = getattr(self.tokenizer, "eos_token_id", None)
|
| 163 |
+
|
| 164 |
+
generated = torch.full((batch_size, 1), bos, dtype=torch.long, device=device)
|
| 165 |
+
expanded_memory_mask = None
|
| 166 |
+
if memory_mask is not None:
|
| 167 |
+
expanded_memory_mask = memory_mask.to(device=device, dtype=torch.bool)
|
| 168 |
+
|
| 169 |
+
for _ in range(max(1, max_len) - 1):
|
| 170 |
+
decoder_out = self.model.decoder(generated, memory, memory_mask=expanded_memory_mask)
|
| 171 |
+
logits = decoder_out if isinstance(decoder_out, torch.Tensor) else decoder_out[0]
|
| 172 |
+
|
| 173 |
+
step_logits = logits[:, -1, :].clone()
|
| 174 |
+
if bos is not None and bos < step_logits.size(-1):
|
| 175 |
+
step_logits[:, bos] = float("-inf")
|
| 176 |
+
if pad is not None and pad < step_logits.size(-1):
|
| 177 |
+
step_logits[:, pad] = float("-inf")
|
| 178 |
+
|
| 179 |
+
next_token = step_logits.argmax(dim=-1, keepdim=True)
|
| 180 |
+
generated = torch.cat([generated, next_token], dim=1)
|
| 181 |
+
|
| 182 |
+
if eos is not None and torch.all(next_token.squeeze(-1) == eos):
|
| 183 |
+
break
|
| 184 |
+
|
| 185 |
+
return generated
|
| 186 |
+
|
| 187 |
+
def _trim_special_tokens(self, sequence: Sequence[int]) -> List[int]:
    """Strip special tokens from one generated id sequence.

    A BOS id in the first position is dropped, PAD ids are skipped wherever
    they occur, and everything from the first EOS id onwards is discarded.
    PAD and EOS checks are skipped when the tokenizer does not define them.

    Args:
        sequence: Token ids of a single generated sequence.

    Returns:
        The remaining ids, each converted with ``int``, in original order.
    """
    tokenizer = self.tokenizer
    bos_id = tokenizer.bos_token_id
    pad_id = getattr(tokenizer, "pad_token_id", None)
    eos_id = getattr(tokenizer, "eos_token_id", None)

    token_ids = list(sequence)
    # Only a BOS in the very first slot is removed unconditionally.
    if token_ids and bos_id is not None and token_ids[0] == bos_id:
        token_ids = token_ids[1:]

    kept: List[int] = []
    for token_id in token_ids:
        # PAD is tested before EOS, so a shared PAD/EOS id is skipped, not a stop.
        if pad_id is not None and token_id == pad_id:
            continue
        if eos_id is not None and token_id == eos_id:
            break
        kept.append(int(token_id))
    return kept
|
| 204 |
+
|
| 205 |
def _batch_to_device(self, batch: Batch) -> Batch:
|
| 206 |
tensor_updates: dict[str, torch.Tensor] = {}
|
| 207 |
for item in fields(batch):
|