OliverPerrin committed on
Commit
c8c20f1
·
1 Parent(s): 0f7aa90

Fixing Model Pipeline Problem

Browse files
Files changed (2) hide show
  1. scripts/demo_gradio.py +181 -12
  2. src/inference/pipeline.py +63 -9
scripts/demo_gradio.py CHANGED
@@ -3,11 +3,13 @@ Gradio Demo interface for LexiMind NLP pipeline.
3
  Showcases summarization, emotion detection, and topic prediction.
4
  """
5
  import json
 
6
  import sys
7
  from pathlib import Path
8
  from tempfile import NamedTemporaryFile
9
  from typing import Iterable, Sequence
10
  from textwrap import dedent
 
11
 
12
  import gradio as gr
13
  from gradio.themes import Soft
@@ -31,6 +33,69 @@ logger = get_logger(__name__)
31
  _pipeline: InferencePipeline | None = None # Global pipeline instance
32
  _label_metadata = None # Cached label metadata
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def get_pipeline() -> InferencePipeline:
35
  """Lazy Loading and Caching the inference pipeline"""
36
  global _pipeline, _label_metadata
@@ -88,30 +153,60 @@ def predict(text: str, compression: int):
88
  logger.info("Generating summary with max length of %s", max_len)
89
 
90
  summary = pipeline.summarize([text], max_length=max_len)[0]
91
- emotions = pipeline.predict_emotions([text], threshold=0.0)[0]
 
92
  topic = pipeline.predict_topics([text])[0]
93
 
94
  clean_summary = summary.strip()
95
- if clean_summary:
 
 
 
96
  summary_source = clean_summary
97
- summary_notice = ""
98
  else:
99
- summary_source = generate_fallback_summary(text)
100
- summary_notice = (
101
- "<p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> "
102
- "The model did not produce a summary, so a fallback extractive summary is shown instead.</p>"
103
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  summary_html = format_summary(text, summary_source, notice=summary_notice)
106
- emotion_plot = create_emotion_plot(emotions)
 
 
 
 
107
  topic_output = format_topic(topic)
108
- if clean_summary:
109
  attention_fig = create_attention_heatmap(text, clean_summary, pipeline)
110
  else:
111
  attention_fig = render_unavailable_message(
112
- "Attention heatmap unavailable because the model produced an empty summary."
113
  )
114
- download_path = prepare_download(text, summary_source, emotions, topic)
 
 
 
 
 
 
 
 
115
  download_update = gr.update(value=download_path, visible=True)
116
 
117
  return summary_html, emotion_plot, topic_output, attention_fig, download_update
@@ -212,6 +307,70 @@ def _clean_tokens(tokens: Iterable[str]) -> list[str]:
212
  cleaned.append(item.strip() if item.strip() else token)
213
  return cleaned
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
216
  """Build a lightweight extractive summary when the model generates nothing."""
217
  if not text.strip():
@@ -335,6 +494,10 @@ def prepare_download(
335
  summary: str,
336
  emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
337
  topic: TopicPrediction | dict[str, float | str],
 
 
 
 
338
  ) -> str:
339
  """Persist JSON payload to a temporary file and return its path for download."""
340
  if isinstance(emotions, EmotionPrediction):
@@ -362,9 +525,15 @@ def prepare_download(
362
  payload = {
363
  "original_text": text,
364
  "summary": summary,
 
 
365
  "emotions": emotion_payload,
366
  "topic": topic_payload,
367
  }
 
 
 
 
368
  with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
369
  json.dump(payload, handle, ensure_ascii=False, indent=2)
370
  temp_path = handle.name
 
3
  Showcases summarization, emotion detection, and topic prediction.
4
  """
5
  import json
6
+ import re
7
  import sys
8
  from pathlib import Path
9
  from tempfile import NamedTemporaryFile
10
  from typing import Iterable, Sequence
11
  from textwrap import dedent
12
+ from collections import Counter
13
 
14
  import gradio as gr
15
  from gradio.themes import Soft
 
33
  _pipeline: InferencePipeline | None = None # Global pipeline instance
34
  _label_metadata = None # Cached label metadata
35
 
36
+ STOPWORDS = {
37
+ "the",
38
+ "is",
39
+ "a",
40
+ "an",
41
+ "to",
42
+ "of",
43
+ "and",
44
+ "in",
45
+ "it",
46
+ "that",
47
+ "for",
48
+ "on",
49
+ "with",
50
+ "as",
51
+ "by",
52
+ "be",
53
+ "are",
54
+ "was",
55
+ "were",
56
+ "this",
57
+ "which",
58
+ "at",
59
+ "or",
60
+ "from",
61
+ "but",
62
+ "has",
63
+ "have",
64
+ "had",
65
+ "can",
66
+ "will",
67
+ "would",
68
+ }
69
+
70
+ EMOTION_THRESHOLDS = {
71
+ "anger": 0.6,
72
+ "fear": 0.85,
73
+ "joy": 0.6,
74
+ "love": 0.25,
75
+ "sadness": 0.3,
76
+ "surprise": 0.55,
77
+ }
78
+
79
+ EMOTION_KEYWORDS = {
80
+ "love": {
81
+ "love",
82
+ "loved",
83
+ "loving",
84
+ "beloved",
85
+ "romance",
86
+ "romantic",
87
+ "affection",
88
+ "passion",
89
+ "sweetheart",
90
+ "valentine",
91
+ "dear",
92
+ "cherish",
93
+ "ador",
94
+ "marriage",
95
+ "wedding",
96
+ }
97
+ }
98
+
99
  def get_pipeline() -> InferencePipeline:
100
  """Lazy Loading and Caching the inference pipeline"""
101
  global _pipeline, _label_metadata
 
153
  logger.info("Generating summary with max length of %s", max_len)
154
 
155
  summary = pipeline.summarize([text], max_length=max_len)[0]
156
+ raw_emotion_pairs = extract_emotion_pairs(pipeline.predict_emotions([text], threshold=0.0)[0])
157
+ filtered_emotions = filter_emotions(raw_emotion_pairs, text)
158
  topic = pipeline.predict_topics([text])[0]
159
 
160
  clean_summary = summary.strip()
161
+ summary_notice = ""
162
+ fallback_summary: str | None = None
163
+
164
+ if clean_summary and summary_is_plausible(clean_summary, text):
165
  summary_source = clean_summary
 
166
  else:
167
+ fallback_summary = generate_fallback_summary(text)
168
+ summary_source = fallback_summary
169
+ if clean_summary:
170
+ logger.info("Neural summary flagged as low-overlap; showing extractive fallback instead")
171
+ summary_notice = dedent(
172
+ f"""
173
+ <p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> The neural summary looked off-topic, so an extractive fallback is shown above.</p>
174
+ <details style=\"margin-top: 4px;\">
175
+ <summary style=\"color: #b45309; cursor: pointer;\">View the original neural summary</summary>
176
+ <p style=\"margin-top: 8px; background-color: #fff7ed; padding: 10px; border-radius: 4px; color: #7c2d12; white-space: pre-wrap;\">
177
+ {clean_summary}
178
+ </p>
179
+ </details>
180
+ """
181
+ ).strip()
182
+ else:
183
+ summary_notice = (
184
+ "<p style=\"color: #b45309; margin-top: 8px;\"><strong>Heads-up:</strong> "
185
+ "The model did not produce a summary, so an extractive fallback is shown instead.</p>"
186
+ )
187
 
188
  summary_html = format_summary(text, summary_source, notice=summary_notice)
189
+ emotion_plot = create_emotion_plot(filtered_emotions)
190
+ if emotion_plot is None:
191
+ emotion_plot = render_unavailable_message(
192
+ "No emotion met the confidence threshold."
193
+ )
194
  topic_output = format_topic(topic)
195
+ if clean_summary and fallback_summary is None:
196
  attention_fig = create_attention_heatmap(text, clean_summary, pipeline)
197
  else:
198
  attention_fig = render_unavailable_message(
199
+ "Attention heatmap unavailable because the neural summary was empty or flagged as unreliable."
200
  )
201
+ download_path = prepare_download(
202
+ text,
203
+ summary_source,
204
+ filtered_emotions,
205
+ topic,
206
+ neural_summary=clean_summary or None,
207
+ fallback_summary=fallback_summary,
208
+ raw_emotions=raw_emotion_pairs,
209
+ )
210
  download_update = gr.update(value=download_path, visible=True)
211
 
212
  return summary_html, emotion_plot, topic_output, attention_fig, download_update
 
307
  cleaned.append(item.strip() if item.strip() else token)
308
  return cleaned
309
 
310
+ def extract_emotion_pairs(
311
+ emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]]
312
+ ) -> list[tuple[str, float]]:
313
+ if isinstance(emotions, EmotionPrediction):
314
+ return list(zip(map(str, emotions.labels), map(float, emotions.scores)))
315
+ labels = emotions.get("labels", [])
316
+ scores = emotions.get("scores", [])
317
+ return [(str(label), float(score)) for label, score in zip(labels, scores)]
318
+
319
+ def filter_emotions(pairs: list[tuple[str, float]], text: str) -> EmotionPrediction:
320
+ filtered: list[tuple[str, float]] = []
321
+ lowered_text = text.lower()
322
+
323
+ for label, score in pairs:
324
+ threshold = EMOTION_THRESHOLDS.get(label, 0.5)
325
+ if score < threshold:
326
+ continue
327
+
328
+ if label == "love":
329
+ keywords = EMOTION_KEYWORDS.get("love", set())
330
+ if score < 0.6 and not any(keyword in lowered_text for keyword in keywords):
331
+ continue
332
+
333
+ filtered.append((label, score))
334
+
335
+ if filtered:
336
+ labels, scores = zip(*filtered)
337
+ return EmotionPrediction(labels=list(labels), scores=list(scores))
338
+
339
+ return EmotionPrediction(labels=[], scores=[])
340
+
341
+ def summary_is_plausible(
342
+ summary: str,
343
+ original: str,
344
+ *,
345
+ min_overlap: float = 0.2,
346
+ min_unique_ratio: float = 0.3,
347
+ max_repeat_ratio: float = 0.6,
348
+ ) -> bool:
349
+ """Heuristic filter to catch off-topic or repetitive neural summaries."""
350
+
351
+ summary_tokens = re.findall(r"\w+", summary.lower())
352
+ if not summary_tokens:
353
+ return False
354
+
355
+ summary_content = [token for token in summary_tokens if token not in STOPWORDS]
356
+ if not summary_content:
357
+ return False
358
+
359
+ original_vocab = {token for token in re.findall(r"\w+", original.lower()) if token not in STOPWORDS}
360
+ overlap = sum(1 for token in summary_content if token in original_vocab)
361
+ overlap_ratio = overlap / max(1, len(summary_content))
362
+ if overlap_ratio < min_overlap:
363
+ return False
364
+
365
+ token_counts = Counter(summary_content)
366
+ most_common_ratio = token_counts.most_common(1)[0][1] / len(summary_content)
367
+ unique_ratio = len(token_counts) / len(summary_content)
368
+ if unique_ratio < min_unique_ratio:
369
+ return False
370
+ if most_common_ratio > max_repeat_ratio:
371
+ return False
372
+ return True
373
+
374
  def generate_fallback_summary(text: str, max_chars: int = 320) -> str:
375
  """Build a lightweight extractive summary when the model generates nothing."""
376
  if not text.strip():
 
494
  summary: str,
495
  emotions: EmotionPrediction | dict[str, Sequence[float] | Sequence[str]],
496
  topic: TopicPrediction | dict[str, float | str],
497
+ *,
498
+ neural_summary: str | None = None,
499
+ fallback_summary: str | None = None,
500
+ raw_emotions: Sequence[tuple[str, float]] | None = None,
501
  ) -> str:
502
  """Persist JSON payload to a temporary file and return its path for download."""
503
  if isinstance(emotions, EmotionPrediction):
 
525
  payload = {
526
  "original_text": text,
527
  "summary": summary,
528
+ "neural_summary": neural_summary,
529
+ "fallback_summary": fallback_summary,
530
  "emotions": emotion_payload,
531
  "topic": topic_payload,
532
  }
533
+ if raw_emotions is not None:
534
+ payload["raw_emotions"] = [
535
+ {"label": label, "score": float(score)} for label, score in raw_emotions
536
+ ]
537
  with NamedTemporaryFile("w", delete=False, suffix=".json", encoding="utf-8") as handle:
538
  json.dump(payload, handle, ensure_ascii=False, indent=2)
539
  temp_path = handle.name
src/inference/pipeline.py CHANGED
@@ -75,15 +75,13 @@ class InferencePipeline:
75
  with torch.inference_mode():
76
  encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
77
  memory = self.model.encoder(src_ids, mask=encoder_mask)
78
- generated = self.model.decoder.greedy_decode(
79
- memory=memory,
80
- max_len=max_len,
81
- start_token_id=self.tokenizer.bos_token_id,
82
- end_token_id=self.tokenizer.eos_token_id,
83
- device=self.device,
84
- )
85
 
86
- return self.tokenizer.decode_batch(generated.tolist())
 
 
 
 
87
 
88
  def predict_emotions(
89
  self,
@@ -98,7 +96,7 @@ class InferencePipeline:
98
 
99
  batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
100
  model_inputs = self._batch_to_model_inputs(batch)
101
- decision_threshold = threshold or self.config.emotion_threshold
102
 
103
  with torch.inference_mode():
104
  logits = self.model.forward("emotion", model_inputs)
@@ -148,6 +146,62 @@ class InferencePipeline:
148
  "topic": self.predict_topics(text_list),
149
  }
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  def _batch_to_device(self, batch: Batch) -> Batch:
152
  tensor_updates: dict[str, torch.Tensor] = {}
153
  for item in fields(batch):
 
75
  with torch.inference_mode():
76
  encoder_mask = src_mask.unsqueeze(1) & src_mask.unsqueeze(2) if src_mask is not None else None
77
  memory = self.model.encoder(src_ids, mask=encoder_mask)
78
+ generated = self._constrained_greedy_decode(memory, max_len, memory_mask=src_mask)
 
 
 
 
 
 
79
 
80
+ trimmed_sequences: List[List[int]] = []
81
+ for row in generated.cpu().tolist():
82
+ trimmed_sequences.append(self._trim_special_tokens(row))
83
+
84
+ return self.tokenizer.decode_batch(trimmed_sequences)
85
 
86
  def predict_emotions(
87
  self,
 
96
 
97
  batch = self._batch_to_device(self.preprocessor.batch_encode(texts))
98
  model_inputs = self._batch_to_model_inputs(batch)
99
+ decision_threshold = self.config.emotion_threshold if threshold is None else float(threshold)
100
 
101
  with torch.inference_mode():
102
  logits = self.model.forward("emotion", model_inputs)
 
146
  "topic": self.predict_topics(text_list),
147
  }
148
 
149
+ def _constrained_greedy_decode(
150
+ self,
151
+ memory: torch.Tensor,
152
+ max_len: int,
153
+ *,
154
+ memory_mask: torch.Tensor | None = None,
155
+ ) -> torch.Tensor:
156
+ """Run greedy decoding while banning BOS/PAD tokens from the generated sequence."""
157
+
158
+ device = memory.device
159
+ batch_size = memory.size(0)
160
+ bos = self.tokenizer.bos_token_id
161
+ pad = getattr(self.tokenizer, "pad_token_id", None)
162
+ eos = getattr(self.tokenizer, "eos_token_id", None)
163
+
164
+ generated = torch.full((batch_size, 1), bos, dtype=torch.long, device=device)
165
+ expanded_memory_mask = None
166
+ if memory_mask is not None:
167
+ expanded_memory_mask = memory_mask.to(device=device, dtype=torch.bool)
168
+
169
+ for _ in range(max(1, max_len) - 1):
170
+ decoder_out = self.model.decoder(generated, memory, memory_mask=expanded_memory_mask)
171
+ logits = decoder_out if isinstance(decoder_out, torch.Tensor) else decoder_out[0]
172
+
173
+ step_logits = logits[:, -1, :].clone()
174
+ if bos is not None and bos < step_logits.size(-1):
175
+ step_logits[:, bos] = float("-inf")
176
+ if pad is not None and pad < step_logits.size(-1):
177
+ step_logits[:, pad] = float("-inf")
178
+
179
+ next_token = step_logits.argmax(dim=-1, keepdim=True)
180
+ generated = torch.cat([generated, next_token], dim=1)
181
+
182
+ if eos is not None and torch.all(next_token.squeeze(-1) == eos):
183
+ break
184
+
185
+ return generated
186
+
187
+ def _trim_special_tokens(self, sequence: Sequence[int]) -> List[int]:
188
+ """Remove leading BOS and trailing PAD/EOS tokens from a generated sequence."""
189
+
190
+ bos = self.tokenizer.bos_token_id
191
+ pad = getattr(self.tokenizer, "pad_token_id", None)
192
+ eos = getattr(self.tokenizer, "eos_token_id", None)
193
+
194
+ trimmed: List[int] = []
195
+ for idx, token in enumerate(sequence):
196
+ if idx == 0 and bos is not None and token == bos:
197
+ continue
198
+ if pad is not None and token == pad:
199
+ continue
200
+ if eos is not None and token == eos:
201
+ break
202
+ trimmed.append(int(token))
203
+ return trimmed
204
+
205
  def _batch_to_device(self, batch: Batch) -> Batch:
206
  tensor_updates: dict[str, torch.Tensor] = {}
207
  for item in fields(batch):