EphAsad committed on
Commit
e352007
·
verified ·
1 Parent(s): c633777

Update rag/rag_generator.py

Browse files
Files changed (1) hide show
  1. rag/rag_generator.py +25 -30
rag/rag_generator.py CHANGED
@@ -29,7 +29,7 @@ from __future__ import annotations
29
  import os
30
  import re
31
  import torch
32
- from transformers import T5ForConditionalGeneration, T5Tokenizer
33
 
34
 
35
  # ------------------------------------------------------------
@@ -38,8 +38,8 @@ from transformers import T5ForConditionalGeneration, T5Tokenizer
38
 
39
  MODEL_NAME = "facebook/bart-large"
40
 
41
- _tokenizer: T5Tokenizer | None = None
42
- _model: T5ForConditionalGeneration | None = None
43
 
44
  _MAX_INPUT_TOKENS = 1020
45
  _DEFAULT_MAX_NEW_TOKENS = 256
@@ -47,15 +47,15 @@ _DEFAULT_MAX_NEW_TOKENS = 256
47
  _CONTEXT_CHAR_CAP = 2400
48
 
49
 
50
- def _get_model() -> tuple[T5Tokenizer, T5ForConditionalGeneration]:
51
  global _tokenizer, _model
52
  if _tokenizer is None or _model is None:
53
- _tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
54
- _model = T5ForConditionalGeneration.from_pretrained(
55
- MODEL_NAME,
56
- device_map="auto",
57
- torch_dtype=torch.float32,
58
- )
59
  return _tokenizer, _model
60
 
61
 
@@ -136,7 +136,7 @@ def _looks_like_echo_or_garbage(text: str) -> bool:
136
  return True
137
 
138
  if "." not in s and "match" not in low and "conflict" not in low:
139
- return True
140
 
141
  return False
142
 
@@ -147,12 +147,16 @@ def _looks_like_echo_or_garbage(text: str) -> bool:
147
 
148
  _KEY_MATCHES_HEADER_RE = re.compile(r"^\s*KEY MATCHES\s*:\s*$", re.IGNORECASE)
149
  _CONFLICTS_HEADER_RE = re.compile(r"^\s*CONFLICTS\b.*:\s*$", re.IGNORECASE)
150
- _CONFLICTS_INLINE_NONE_RE = re.compile(r"^\s*CONFLICTS\s*:\s*not specified\.?\s*$", re.IGNORECASE)
 
 
 
151
 
152
  _MATCH_LINE_RE = re.compile(
153
  r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(matches reference:\s*(.+?)\)\s*$",
154
  re.IGNORECASE,
155
  )
 
156
  _CONFLICT_LINE_RE = re.compile(
157
  r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(conflicts reference:\s*(.+?)\)\s*$",
158
  re.IGNORECASE,
@@ -161,7 +165,10 @@ _CONFLICT_LINE_RE = re.compile(
161
  _GENERIC_BULLET_RE = re.compile(r"^\s*-\s*(.+?)\s*$")
162
 
163
 
164
- def _extract_key_traits_and_conflicts(shaped_ctx: str) -> tuple[list[str], list[str], bool]:
 
 
 
165
  key_traits: list[str] = []
166
  conflicts: list[str] = []
167
 
@@ -247,7 +254,7 @@ def _format_bullets(items: list[str], *, none_text: str) -> str:
247
 
248
 
249
  # ------------------------------------------------------------
250
- # NEW — CONFIDENCE STATE EXTRACTOR
251
  # ------------------------------------------------------------
252
 
253
  _CONF_STATE_RE = re.compile(r"^\s*-\s*Confidence State:\s*(.+?)\s*$", re.IGNORECASE)
@@ -274,7 +281,7 @@ def _extract_confidence_state(shaped_ctx: str) -> tuple[str | None, str | None]:
274
 
275
 
276
  # ------------------------------------------------------------
277
- # DETERMINISTIC TEMPLATES (PRIMARY AUTHORITY)
278
  # ------------------------------------------------------------
279
 
280
  def _template_conclusion(
@@ -292,12 +299,10 @@ def _template_conclusion(
292
  c_short = ", ".join(conflicts[:3]) if conflicts else None
293
 
294
  if confidence_state is None:
295
- # fallback to legacy deterministic behaviour
296
  return _deterministic_conclusion(g, key_traits, conflicts)
297
 
298
  cs = confidence_state.lower()
299
 
300
- # ---------- STRONG MATCH ----------
301
  if "strong" in cs:
302
  return (
303
  f"This phenotype is indicative of {g} with no conflicting traits observed. "
@@ -305,7 +310,6 @@ def _template_conclusion(
305
  f"{rec}".strip()
306
  )
307
 
308
- # ---------- PROBABLE MATCH ----------
309
  if "probable" in cs or "conflicts present" in cs or "cautious" in cs:
310
  if c_short:
311
  return (
@@ -319,26 +323,23 @@ def _template_conclusion(
319
  f"reduces certainty. {rec}".strip()
320
  )
321
 
322
- # ---------- INCONCLUSIVE ----------
323
  if "inconclusive" in cs or "conflicting" in cs:
324
  return (
325
  f"The top genus match is {g}; however, the phenotype is inconclusive due to conflicting "
326
  f"test results ({c_short or 'multiple conflicting traits'}). {rec}".strip()
327
  )
328
 
329
- # ---------- WEAK EVIDENCE ----------
330
  if "weak" in cs:
331
  return (
332
  f"The available phenotype provides weak evidence for {g}. "
333
  f"Additional testing or phenotype data is recommended. {rec}".strip()
334
  )
335
 
336
- # Unknown confidence state → safe fallback
337
  return _deterministic_conclusion(g, key_traits, conflicts)
338
 
339
 
340
  # ------------------------------------------------------------
341
- # LEGACY DETERMINISTIC CONCLUSION (FINAL BACKSTOP)
342
  # ------------------------------------------------------------
343
 
344
  def _deterministic_conclusion(genus: str, key_traits: list[str], conflicts: list[str]) -> str:
@@ -405,7 +406,6 @@ def generate_genus_rag_explanation(
405
  "No reference evidence was available to evaluate this genus against the observed phenotype."
406
  )
407
 
408
- # Extract evidence sections
409
  key_traits, conflicts, saw_headers = _extract_key_traits_and_conflicts(context)
410
 
411
  if (not saw_headers) or (not key_traits and not conflicts):
@@ -417,10 +417,8 @@ def generate_genus_rag_explanation(
417
  key_traits_text = _format_bullets(key_traits, none_text="- Not specified.")
418
  conflicts_text = _format_bullets(conflicts, none_text="- Not specified.")
419
 
420
- # ---------- NEW: read CONFIDENCE STATE ----------
421
  confidence_state, recommended_action = _extract_confidence_state(context)
422
 
423
- # Generate deterministic template conclusion first
424
  template_conclusion = _template_conclusion(
425
  genus_clean,
426
  confidence_state,
@@ -429,11 +427,10 @@ def generate_genus_rag_explanation(
429
  recommended_action,
430
  )
431
 
432
- # If no confidence state → skip LLM entirely
433
  if confidence_state is None:
434
  final_conclusion = template_conclusion
 
435
  else:
436
- # Ask LLM to paraphrase but not reinterpret
437
  prompt = RAG_PROMPT.format(
438
  genus=genus_clean,
439
  confidence_state=confidence_state,
@@ -456,7 +453,7 @@ def generate_genus_rag_explanation(
456
  **inputs,
457
  max_new_tokens=max_new_tokens,
458
  temperature=0.0,
459
- num_beams=1,
460
  do_sample=False,
461
  repetition_penalty=1.2,
462
  no_repeat_ngram_size=3,
@@ -469,7 +466,6 @@ def generate_genus_rag_explanation(
469
  _log_block("RAW OUTPUT (CONCLUSION)", decoded)
470
  _log_block("CLEANED OUTPUT", cleaned)
471
 
472
- # If LLM garbage → revert to deterministic template
473
  if _looks_like_echo_or_garbage(cleaned):
474
  final_conclusion = template_conclusion
475
  if RAG_GEN_LOG_OUTPUT:
@@ -477,7 +473,6 @@ def generate_genus_rag_explanation(
477
  else:
478
  final_conclusion = cleaned
479
 
480
- # ---------- Final structured output ----------
481
  final = (
482
  "KEY TRAITS:\n"
483
  f"{key_traits_text}\n\n"
 
29
  import os
30
  import re
31
  import torch
32
+ from transformers import BartForConditionalGeneration, BartTokenizer
33
 
34
 
35
  # ------------------------------------------------------------
 
38
 
39
  MODEL_NAME = "facebook/bart-large"
40
 
41
+ _tokenizer: BartTokenizer | None = None
42
+ _model: BartForConditionalGeneration | None = None
43
 
44
  _MAX_INPUT_TOKENS = 1020
45
  _DEFAULT_MAX_NEW_TOKENS = 256
 
47
  _CONTEXT_CHAR_CAP = 2400
48
 
49
 
50
def _get_model() -> tuple[BartTokenizer, BartForConditionalGeneration]:
    """Return the cached (tokenizer, model) pair, loading it on first call.

    The module-level singletons ``_tokenizer`` / ``_model`` are populated
    lazily so the large checkpoint is only fetched when generation is
    actually requested; subsequent calls reuse the same objects.
    """
    global _tokenizer, _model
    uninitialised = _tokenizer is None or _model is None
    if uninitialised:
        _tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
        # NOTE(review): device_map="auto" requires the `accelerate` package
        # to be installed — confirm it is a declared dependency.
        _model = BartForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.float32,
        )
    return _tokenizer, _model
60
 
61
 
 
136
  return True
137
 
138
  if "." not in s and "match" not in low and "conflict" not in low:
139
+ return True
140
 
141
  return False
142
 
 
147
 
148
  _KEY_MATCHES_HEADER_RE = re.compile(r"^\s*KEY MATCHES\s*:\s*$", re.IGNORECASE)
149
  _CONFLICTS_HEADER_RE = re.compile(r"^\s*CONFLICTS\b.*:\s*$", re.IGNORECASE)
150
+ _CONFLICTS_INLINE_NONE_RE = re.compile(
151
+ r"^\s*CONFLICTS\s*:\s*not specified\.?\s*$",
152
+ re.IGNORECASE,
153
+ )
154
 
155
  _MATCH_LINE_RE = re.compile(
156
  r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(matches reference:\s*(.+?)\)\s*$",
157
  re.IGNORECASE,
158
  )
159
+
160
  _CONFLICT_LINE_RE = re.compile(
161
  r"^\s*-\s*([^:]+)\s*:\s*(.+?)\s*\(conflicts reference:\s*(.+?)\)\s*$",
162
  re.IGNORECASE,
 
165
  _GENERIC_BULLET_RE = re.compile(r"^\s*-\s*(.+?)\s*$")
166
 
167
 
168
+ def _extract_key_traits_and_conflicts(
169
+ shaped_ctx: str,
170
+ ) -> tuple[list[str], list[str], bool]:
171
+
172
  key_traits: list[str] = []
173
  conflicts: list[str] = []
174
 
 
254
 
255
 
256
  # ------------------------------------------------------------
257
+ # CONFIDENCE STATE EXTRACTOR
258
  # ------------------------------------------------------------
259
 
260
  _CONF_STATE_RE = re.compile(r"^\s*-\s*Confidence State:\s*(.+?)\s*$", re.IGNORECASE)
 
281
 
282
 
283
  # ------------------------------------------------------------
284
+ # DETERMINISTIC TEMPLATES
285
  # ------------------------------------------------------------
286
 
287
  def _template_conclusion(
 
299
  c_short = ", ".join(conflicts[:3]) if conflicts else None
300
 
301
  if confidence_state is None:
 
302
  return _deterministic_conclusion(g, key_traits, conflicts)
303
 
304
  cs = confidence_state.lower()
305
 
 
306
  if "strong" in cs:
307
  return (
308
  f"This phenotype is indicative of {g} with no conflicting traits observed. "
 
310
  f"{rec}".strip()
311
  )
312
 
 
313
  if "probable" in cs or "conflicts present" in cs or "cautious" in cs:
314
  if c_short:
315
  return (
 
323
  f"reduces certainty. {rec}".strip()
324
  )
325
 
 
326
  if "inconclusive" in cs or "conflicting" in cs:
327
  return (
328
  f"The top genus match is {g}; however, the phenotype is inconclusive due to conflicting "
329
  f"test results ({c_short or 'multiple conflicting traits'}). {rec}".strip()
330
  )
331
 
 
332
  if "weak" in cs:
333
  return (
334
  f"The available phenotype provides weak evidence for {g}. "
335
  f"Additional testing or phenotype data is recommended. {rec}".strip()
336
  )
337
 
 
338
  return _deterministic_conclusion(g, key_traits, conflicts)
339
 
340
 
341
  # ------------------------------------------------------------
342
+ # BACKSTOP DETERMINISTIC CONCLUSION
343
  # ------------------------------------------------------------
344
 
345
  def _deterministic_conclusion(genus: str, key_traits: list[str], conflicts: list[str]) -> str:
 
406
  "No reference evidence was available to evaluate this genus against the observed phenotype."
407
  )
408
 
 
409
  key_traits, conflicts, saw_headers = _extract_key_traits_and_conflicts(context)
410
 
411
  if (not saw_headers) or (not key_traits and not conflicts):
 
417
  key_traits_text = _format_bullets(key_traits, none_text="- Not specified.")
418
  conflicts_text = _format_bullets(conflicts, none_text="- Not specified.")
419
 
 
420
  confidence_state, recommended_action = _extract_confidence_state(context)
421
 
 
422
  template_conclusion = _template_conclusion(
423
  genus_clean,
424
  confidence_state,
 
427
  recommended_action,
428
  )
429
 
 
430
  if confidence_state is None:
431
  final_conclusion = template_conclusion
432
+
433
  else:
 
434
  prompt = RAG_PROMPT.format(
435
  genus=genus_clean,
436
  confidence_state=confidence_state,
 
453
  **inputs,
454
  max_new_tokens=max_new_tokens,
455
  temperature=0.0,
456
+ num_beams=3, # BART benefits from small beam search
457
  do_sample=False,
458
  repetition_penalty=1.2,
459
  no_repeat_ngram_size=3,
 
466
  _log_block("RAW OUTPUT (CONCLUSION)", decoded)
467
  _log_block("CLEANED OUTPUT", cleaned)
468
 
 
469
  if _looks_like_echo_or_garbage(cleaned):
470
  final_conclusion = template_conclusion
471
  if RAG_GEN_LOG_OUTPUT:
 
473
  else:
474
  final_conclusion = cleaned
475
 
 
476
  final = (
477
  "KEY TRAITS:\n"
478
  f"{key_traits_text}\n\n"