manpreet88 committed on
Commit
30fd755
·
1 Parent(s): 8b37da0

Update orchestrator.py

Browse files
Files changed (1) hide show
  1. PolyAgent/orchestrator.py +174 -288
PolyAgent/orchestrator.py CHANGED
@@ -1,3 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  import json
@@ -6,7 +20,7 @@ import sys
6
  from pathlib import Path
7
  from typing import Dict, Any, List, Optional, Tuple
8
  from urllib.parse import urlparse
9
- from typing import Optional
10
  import numpy as np
11
  import torch
12
  import torch.nn as nn
@@ -47,7 +61,7 @@ try:
47
  except Exception:
48
  spm = None
49
 
50
- # Optional: selfies (for SELFIES→SMILES/PSMILES conversion, as in G2)
51
  try:
52
  import selfies as sf
53
  except Exception:
@@ -57,11 +71,36 @@ RDKit_AVAILABLE = Chem is not None
57
  SELFIES_AVAILABLE = sf is not None
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # =============================================================================
61
  # DOI NORMALIZATION / RESOLUTION HELPERS
62
  # =============================================================================
63
  _DOI_RE = re.compile(r"^10\.\d{4,9}/\S+$", re.IGNORECASE)
64
 
 
65
  def normalize_doi(raw: str) -> Optional[str]:
66
  if not isinstance(raw, str):
67
  return None
@@ -75,10 +114,12 @@ def normalize_doi(raw: str) -> Optional[str]:
75
  s = s.rstrip(").,;]}")
76
  return s if _DOI_RE.match(s) else None
77
 
 
78
def doi_to_url(doi: str) -> str:
    """Return the canonical https://doi.org resolver URL for *doi*.

    The caller is expected to pass an already-normalized DOI (see
    normalize_doi); no validation is performed here.
    """
    return "https://doi.org/" + doi
81
 
 
82
  def doi_resolves(doi_url: str, timeout: float = 6.0) -> bool:
83
  """
84
  Best-effort resolver check. Keeps pipeline robust against dead/unregistered DOIs.
@@ -95,8 +136,9 @@ def doi_resolves(doi_url: str, timeout: float = 6.0) -> bool:
95
  except Exception:
96
  return False
97
 
 
98
  # =============================================================================
99
- # CITATION / DOMAIN TAGGING HELPERS (domain-style citations like "(nature.com)")
100
  # =============================================================================
101
  def _url_to_domain(url: str) -> Optional[str]:
102
  if not isinstance(url, str) or not url.strip():
@@ -133,10 +175,12 @@ def _url_to_domain(url: str) -> Optional[str]:
133
  except Exception:
134
  return None
135
 
 
136
  def _attach_source_domains(obj: Any) -> Any:
137
  """
138
  Recursively add a short source_domain field where URLs are present.
139
- This enables domain-style citations like "(nature.com)".
 
140
  """
141
  if isinstance(obj, list):
142
  return [_attach_source_domains(x) for x in obj]
@@ -160,6 +204,7 @@ def _attach_source_domains(obj: Any) -> Any:
160
  def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
161
  """
162
  Add 'cite_tag' fields for citable web/RAG items using DOI-first URL tags.
 
163
  Requirement:
164
  - Paper citations must use the COMPLETE DOI URL (https://doi.org/...) as the bracket text.
165
  - If DOI is not available, fall back to the best http(s) URL.
@@ -181,7 +226,7 @@ def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
181
  return False
182
 
183
  def get_best_url(d: Dict[str, Any]) -> Optional[str]:
184
- # DOI-first: if DOI exists, ALWAYS prefer doi.org for citation link text and href.
185
  doi = normalize_doi(d.get("doi", ""))
186
  if doi:
187
  return doi_to_url(doi)
@@ -199,23 +244,16 @@ def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
199
  out = {k: walk_and_tag(v) for k, v in node.items()}
200
 
201
  if is_citable_item(out):
202
- # Citation tag MUST be DOI URL (preferred) or best URL (fallback).
203
  url = get_best_url(out)
204
  if isinstance(url, str) and url.startswith(("http://", "https://")):
205
- # If an existing cite_tag is non-URL (e.g., a domain tag), replace it.
206
  cur = out.get("cite_tag")
207
  if not (isinstance(cur, str) and cur.strip().startswith(("http://", "https://"))):
208
  out["cite_tag"] = url.strip()
209
- else:
210
- # If we cannot form a URL, leave as-is (should be rare due to is_citable_item).
211
- pass
212
 
213
- # Maintain a compact index (optional, harmless for UIs)
214
  url = get_best_url(out)
215
  dom = out.get("source_domain") or (_url_to_domain(url) if url else None) or "source"
216
  citation_index["sources"].append(
217
  {
218
- # tag is the bracket text requirement (DOI URL or URL)
219
  "tag": out.get("cite_tag") if isinstance(out.get("cite_tag"), str) else url,
220
  "domain": dom,
221
  "title": out.get("title") or out.get("name") or "Untitled",
@@ -236,10 +274,9 @@ def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
236
  return report
237
 
238
 
239
-
240
- # =============================================================================
241
- # ENFORCE INLINE CLICKABLE LITERATURE CITATIONS (distributed, not clustered)
242
- # =============================================================================
243
  _CITE_COUNT_PATTERNS = [
244
  r"(?:at\s+least\s+)?(\d{1,3})\s*(?:citations|citation|papers|paper|sources|source|references|reference)\b",
245
  r"\bcite\s+(\d{1,3})\s*(?:papers|paper|sources|source|references|reference|citations|citation)\b",
@@ -262,8 +299,8 @@ def _infer_required_citation_count(text: str, default_n: int = 10) -> int:
262
 
263
  def _collect_citation_links_from_report(report: Dict[str, Any]) -> List[Tuple[str, str]]:
264
  """
265
- Return unique (domain, url) pairs from report['citation_index']['sources'].
266
- Link text is strictly the root domain. URL must be http(s).
267
  """
268
  out: List[Tuple[str, str]] = []
269
  seen: set = set()
@@ -280,7 +317,6 @@ def _collect_citation_links_from_report(report: Dict[str, Any]) -> List[Tuple[st
280
  url = s.get("url")
281
  if not isinstance(url, str) or not url.startswith(("http://", "https://")):
282
  continue
283
- # cite_text is DOI URL (tag) if present; else fall back to the URL itself.
284
  cite_text = s.get("tag") if isinstance(s.get("tag"), str) and s.get("tag").strip() else url
285
  if not isinstance(cite_text, str) or not cite_text.strip():
286
  cite_text = url
@@ -310,18 +346,15 @@ def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_ne
310
  if not citations:
311
  return md
312
 
313
- # Count existing literature links by URL (any markdown link).
314
  existing_urls = set(re.findall(r"\[[^\]]+\]\((https?://[^)]+)\)", md))
315
  need = max(0, int(min_needed) - len(existing_urls))
316
  if need <= 0:
317
  return md
318
 
319
- # Only use citations not already present.
320
  remaining: List[Tuple[str, str]] = [(d, u) for (d, u) in citations if u not in existing_urls]
321
  if not remaining:
322
  return md
323
 
324
- # Split by fenced code blocks; do not inject inside them.
325
  parts = re.split(r"(```[\s\S]*?```)", md)
326
  rem_i = 0
327
 
@@ -331,7 +364,6 @@ def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_ne
331
  if part.startswith("```") and part.endswith("```"):
332
  continue
333
 
334
- # Split into paragraph blocks (preserve blank-line separators).
335
  segs = re.split(r"(\n\s*\n)", part)
336
  for si in range(0, len(segs), 2):
337
  if rem_i >= len(remaining) or need <= 0:
@@ -339,27 +371,24 @@ def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_ne
339
  para = segs[si]
340
  if not isinstance(para, str) or not para.strip():
341
  continue
342
- # Skip headings.
343
  if para.lstrip().startswith("#"):
344
  continue
345
- # Skip paragraphs that already contain at least one markdown link.
346
  if re.search(r"\[[^\]]+\]\((https?://[^)]+)\)", para):
347
  continue
348
-
349
- # Prefer injecting into evidence-bearing paragraphs first to avoid "clutter".
350
- # If paragraph doesn't look like a literature-backed claim, skip it in this pass.
351
- if not re.search(r"\b(reported|shown|demonstrated|study|studies|literature|evidence|review|according)\b", para, flags=re.IGNORECASE):
 
352
  continue
353
 
354
  cite_text, url = remaining[rem_i]
355
- # Requirement: bracket text is the COMPLETE DOI URL (or URL fallback).
356
  segs[si] = para.rstrip() + f" [{cite_text}]({url})"
357
  rem_i += 1
358
  need -= 1
359
 
360
  parts[pi] = "".join(segs)
361
 
362
- # Second pass (if still need citations): allow any non-heading paragraph without links.
363
  if need > 0 and rem_i < len(remaining):
364
  md2 = "".join(parts)
365
  parts2 = re.split(r"(```[\s\S]*?```)", md2)
@@ -391,9 +420,9 @@ def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_ne
391
 
392
  def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> str:
393
  """
394
- Enforce the single citation requirement:
395
- - Link text must be the COMPLETE DOI URL (preferred) or URL fallback.
396
- - Each DOI/URL must appear at most once in the entire answer.
397
  Only operates outside fenced code blocks.
398
  """
399
  if not isinstance(md, str) or not md.strip():
@@ -401,7 +430,6 @@ def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> s
401
  if not isinstance(report, dict):
402
  return md
403
 
404
- # Build url -> preferred_text mapping (DOI URL / URL)
405
  url_to_text: Dict[str, str] = {}
406
  ci = report.get("citation_index", {})
407
  sources = ci.get("sources") if isinstance(ci, dict) else None
@@ -421,22 +449,19 @@ def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> s
421
 
422
  def _rewrite_and_dedupe(text: str) -> str:
423
  def repl(m: re.Match) -> str:
424
- txt = m.group(1)
425
  url = m.group(2).strip()
426
  if url in seen_urls:
427
- # remove duplicate citation entirely (and any leading space before it if present)
428
  return ""
429
  seen_urls.add(url)
430
  pref = url_to_text.get(url, url)
431
  return f"[{pref}]({url})"
432
- # Rewrite link text to preferred, then dedupe by URL
433
  return re.sub(r"\[([^\]]+)\]\((https?://[^)]+)\)", repl, text)
434
 
435
  for i, part in enumerate(parts):
436
  if part.startswith("```") and part.endswith("```"):
437
  continue
438
  parts[i] = _rewrite_and_dedupe(part)
439
- # Cleanup: collapse double spaces created by removals
440
  parts[i] = re.sub(r"[ \t]{2,}", " ", parts[i])
441
  parts[i] = re.sub(r"\n{3,}", "\n\n", parts[i])
442
 
@@ -446,7 +471,6 @@ def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> s
446
  def autolink_doi_urls(md: str) -> str:
447
  """
448
  Wrap bare DOI URLs in Markdown links outside code blocks.
449
- Prevents plain DOI URLs from rendering as non-clickable text.
450
  """
451
  if not md:
452
  return md
@@ -457,15 +481,18 @@ def autolink_doi_urls(md: str) -> str:
457
  parts[i] = re.sub(
458
  r"(?<!\]\()(?P<u>https?://doi\.org/10\.\d{4,9}/[^\s\)\],;]+)",
459
  lambda m: f"[{m.group('u')}]({m.group('u')})",
460
- part,
461
  flags=re.IGNORECASE,
462
  )
463
  return "".join(parts)
464
 
 
 
 
 
465
  def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
466
  """
467
- Ensure each tool output has a [T#] tag for tool-citation style.
468
- This does NOT modify tool outputs beyond adding a 'cite_tag' key when missing.
469
  """
470
  if not isinstance(report, dict):
471
  return report
@@ -474,8 +501,7 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
474
  if not isinstance(tool_outputs, dict):
475
  return report
476
 
477
- # Stable order (common core tools first)
478
- ordered_tools = [
479
  "data_extraction",
480
  "cl_encoding",
481
  "property_prediction",
@@ -485,12 +511,10 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
485
  "report_generation",
486
  ]
487
 
488
- # Tag assignment: keep existing cite_tags if present
489
  tool_tag_map: Dict[str, str] = {}
490
  tag = "[T]"
491
 
492
- # First pass: assign in preferred order
493
- for tool in ordered_tools:
494
  node = tool_outputs.get(tool)
495
  if node is None:
496
  continue
@@ -498,7 +522,6 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
498
  if isinstance(node, dict) and not node.get("cite_tag"):
499
  node["cite_tag"] = tag
500
 
501
- # Second pass: any remaining tools in tool_outputs
502
  for tool, node in tool_outputs.items():
503
  if tool in tool_tag_map or node is None:
504
  continue
@@ -506,11 +529,9 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
506
  if isinstance(node, dict) and not node.get("cite_tag"):
507
  node["cite_tag"] = tag
508
 
509
- # Also tag summary nodes (best-effort, no structural assumptions)
510
  try:
511
  summary = report.get("summary", {}) or {}
512
  if isinstance(summary, dict):
513
- # common mapping
514
  key_to_tool = {
515
  "data_extraction": "data_extraction",
516
  "cl_encoding": "cl_encoding",
@@ -534,7 +555,7 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
534
 
535
  def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
536
  """
537
- Render tool outputs as verbatim JSON blocks (no tweaking of content values).
538
  """
539
  if not isinstance(report, dict):
540
  return ""
@@ -543,7 +564,6 @@ def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
543
  if not isinstance(tool_outputs, dict):
544
  return ""
545
 
546
- # Prefer a stable display order; include any extra keys afterward
547
  preferred = [
548
  "data_extraction",
549
  "cl_encoding",
@@ -571,28 +591,19 @@ def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
571
 
572
 
573
  # =============================================================================
574
- # PICKLE / JOBLIB COMPATIBILITY SHIMS (Fix generator loading error)
575
  # =============================================================================
576
  class LatentPropertyModel:
577
  """
578
  Compatibility shim for joblib/pickle artifacts saved with references like:
579
  __main__.LatentPropertyModel
580
-
581
- The original training code likely defined this in a script, so pickle recorded it under __main__.
582
- When loading from Gradio, __main__ is different, so unpickling fails.
583
-
584
- This shim is intentionally minimal:
585
- - pickle will restore attributes into this object
586
- - predict(...) attempts to delegate to a plausible underlying model attribute if present
587
  """
588
  def predict(self, X):
589
- # Common patterns: wrapper stores underlying estimator under one of these attributes.
590
  for attr in ("model", "gpr", "gpr_model", "estimator", "predictor", "_model", "_gpr"):
591
  if hasattr(self, attr):
592
  obj = getattr(self, attr)
593
  if hasattr(obj, "predict"):
594
  return obj.predict(X)
595
- # If the wrapper itself has been restored with a custom predict, this will never be hit.
596
  raise AttributeError(
597
  "LatentPropertyModel shim could not find an underlying predictor. "
598
  "Artifact expects a wrapped model attribute with a .predict method."
@@ -602,7 +613,6 @@ class LatentPropertyModel:
602
  def _install_unpickle_shims() -> None:
603
  """
604
  Ensure that any classes pickled under __main__ are available at load time.
605
- This is critical for joblib artifacts created from scripts (training/fit scripts).
606
  """
607
  main_mod = sys.modules.get("__main__")
608
  if main_mod is not None and not hasattr(main_mod, "LatentPropertyModel"):
@@ -620,7 +630,6 @@ def _safe_joblib_load(path: str):
620
  return joblib.load(path)
621
  except Exception as e:
622
  msg = str(e)
623
- # Targeted fix for your exact failure mode
624
  if "Can't get attribute 'LatentPropertyModel' on <module '__main__'" in msg:
625
  _install_unpickle_shims()
626
  return joblib.load(path)
@@ -628,18 +637,50 @@ def _safe_joblib_load(path: str):
628
 
629
 
630
  # =============================================================================
631
- # PATHS (per your 5M pipeline artifacts)
632
  # =============================================================================
633
- DOWNSTREAM_BESTWEIGHTS_5M_DIR = "/home/kaur-m43/multimodal_downstream_bestweights_5M"
634
- INVERSE_DESIGN_5M_DIR = "/home/kaur-m43/multimodal_inverse_design_output_5M_polybart_style/best_models"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
 
636
 
637
  # =============================================================================
638
- # Property name canonicalization
639
  # =============================================================================
640
  def canonical_property_name(name: str) -> str:
641
  """
642
- Map user/tool inputs to the canonical keys used in PROPERTY_HEAD_PATHS/GENERATOR_DIRS.
643
  """
644
  if not isinstance(name, str):
645
  return ""
@@ -663,15 +704,11 @@ def canonical_property_name(name: str) -> str:
663
  return aliases.get(s, s)
664
 
665
 
666
- # =============================================================================
667
- # NEW: best-effort inference of property + target_value from questions text
668
- # (used only when callers omit property/target_value but provide questions)
669
- # =============================================================================
670
  _NUM_RE = r"[-+]?\d+(?:\.\d+)?"
671
 
 
672
  def infer_property_from_text(text: str) -> Optional[str]:
673
  s = (text or "").lower()
674
- # explicit "property: ..."
675
  m = re.search(r"\bproperty\b\s*[:=]\s*([a-zA-Z _-]+)", s)
676
  if m:
677
  cand = m.group(1).strip().lower()
@@ -698,6 +735,7 @@ def infer_property_from_text(text: str) -> Optional[str]:
698
  return "density"
699
  return None
700
 
 
701
  def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[float]:
702
  sl = (text or "").lower()
703
 
@@ -729,7 +767,6 @@ def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[flo
729
  except Exception:
730
  pass
731
 
732
- # token-near-number fallback (within 80 chars)
733
  tokens = []
734
  if prop == "glass transition":
735
  tokens = ["tg", "glass transition"]
@@ -754,33 +791,9 @@ def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[flo
754
 
755
  return None
756
 
757
- PROPERTY_HEAD_PATHS = {
758
- "density": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "density", "best_run_checkpoint.pt"),
759
- "glass transition": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "glass_transition", "best_run_checkpoint.pt"),
760
- "melting": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "melting", "best_run_checkpoint.pt"),
761
- "specific volume": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "specific_volume", "best_run_checkpoint.pt"),
762
- "thermal decomposition": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "thermal_decomposition", "best_run_checkpoint.pt"),
763
- }
764
-
765
- PROPERTY_HEAD_META = {
766
- "density": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "density", "best_run_metadata.json"),
767
- "glass transition": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "glass_transition", "best_run_metadata.json"),
768
- "melting": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "melting", "best_run_metadata.json"),
769
- "specific volume": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "specific_volume", "best_run_metadata.json"),
770
- "thermal decomposition": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "thermal_decomposition", "best_run_metadata.json"),
771
- }
772
-
773
- GENERATOR_DIRS = {
774
- "density": os.path.join(INVERSE_DESIGN_5M_DIR, "density"),
775
- "glass transition": os.path.join(INVERSE_DESIGN_5M_DIR, "glass_transition"),
776
- "melting": os.path.join(INVERSE_DESIGN_5M_DIR, "melting"),
777
- "specific volume": os.path.join(INVERSE_DESIGN_5M_DIR, "specific_volume"),
778
- "thermal decomposition": os.path.join(INVERSE_DESIGN_5M_DIR, "thermal_decomposition"),
779
- }
780
-
781
 
782
  # =============================================================================
783
- # Tokenizers (SentencePiece etc. — kept for backward compatibility)
784
  # =============================================================================
785
  class SimpleCharTokenizer:
786
  def __init__(self, vocab_chars: List[str], special_tokens=("<pad>", "<s>", "</s>", "<unk>")):
@@ -840,7 +853,6 @@ class SentencePieceTokenizerWrapper:
840
  blocked.append(tid)
841
  setattr(self, "_blocked_ids", blocked)
842
 
843
- # Safety: require '*' token
844
  if self.PieceToId("*") is None:
845
  raise RuntimeError("SentencePiece tokenizer loaded but '*' token not found – aborting for safe PSMILES generation.")
846
 
@@ -878,16 +890,18 @@ def psmiles_to_rdkit_smiles(psmiles: str) -> str:
878
  s = re.sub(r"\*", "[*]", s)
879
  return s
880
 
881
- # --- UI-safe endpoint normalization (ONLY [At]/[AT] -> [*]) ---
882
# Matches the UI-style bracketed endpoint marker "[at]" in any letter case
# (e.g. "[At]", "[AT]"); used to restore the canonical "[*]" wildcard.
_AT_BRACKET_UI_RE = re.compile(r"\[(at)\]", flags=re.IGNORECASE)


def replace_at_with_star(psmiles: str) -> str:
    """Normalize UI-safe '[At]'/'[at]' endpoint markers back to '[*]'.

    Non-string or empty inputs are returned unchanged.
    """
    if isinstance(psmiles, str) and psmiles:
        return _AT_BRACKET_UI_RE.sub("[*]", psmiles)
    return psmiles
888
 
 
889
  # =============================================================================
890
- # SELFIES utilities (minimal subset mirroring G2.py behaviour)
891
  # =============================================================================
892
  _SELFIES_TOKEN_RE = re.compile(r"\[[^\[\]]+\]")
893
 
@@ -933,32 +947,25 @@ def selfies_to_smiles(selfies_str: str) -> str:
933
def pselfies_to_psmiles(selfies_str: str) -> str:
    """Convert a pSELFIES string to PSMILES.

    In this orchestrator the conversion is treated simply as
    SELFIES -> canonical SMILES; no additional polymer-specific endpoint
    mapping is applied here.
    """
    converted = selfies_to_smiles(selfies_str)
    return converted
941
 
942
 
943
  # =============================================================================
944
- # SELFIES-TED decoder (as in G2.py, but simplified to core functionality)
945
  # =============================================================================
946
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
947
  SELFIES_TED_MODEL_NAME = os.environ.get("SELFIES_TED_MODEL_NAME", "ibm-research/materials.selfies-ted")
948
 
949
- # Generation hyperparameters (mirroring G2 defaults)
950
  GEN_MAX_LEN = 256
951
  GEN_MIN_LEN = 10
952
  GEN_TOP_P = 0.92
953
  GEN_TEMPERATURE = 1.0
954
  GEN_REPETITION_PENALTY = 1.05
955
- LATENT_NOISE_STD_GEN = 0.15 # default exploration std for generation
956
 
957
 
958
  def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
959
- """
960
- Small helper to make HF loading more robust, copied from G2 spirit.
961
- """
962
  import time
963
  last_err = None
964
  for t in range(max_tries):
@@ -974,7 +981,7 @@ def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
974
 
975
  def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
976
  """
977
- Load tokenizer + seq2seq model for SELFIES-TED, exactly as in G2.py (but without side effects).
978
  """
979
  def _load_tok():
980
  return AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, use_fast=True)
@@ -989,8 +996,7 @@ def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
989
 
990
  class CLConditionedSelfiesTEDGenerator(nn.Module):
991
  """
992
- Same structure as in G2.py: take a CL embedding (latent) and project it
993
- into a fixed-length memory that conditions a SELFIES-TED seq2seq model.
994
  """
995
  def __init__(self, tok, seq2seq_model, cl_emb_dim: int = 600, mem_len: int = 4):
996
  super().__init__()
@@ -1038,9 +1044,6 @@ class CLConditionedSelfiesTEDGenerator(nn.Module):
1038
  temperature: float = GEN_TEMPERATURE,
1039
  repetition_penalty: float = GEN_REPETITION_PENALTY,
1040
  ) -> List[str]:
1041
- """
1042
- Latent→pSELFIES generation, as in G2.
1043
- """
1044
  self.eval()
1045
  z = z.to(next(self.parameters()).device)
1046
  enc_out, attn = self.build_encoder_outputs(z)
@@ -1063,25 +1066,17 @@ class CLConditionedSelfiesTEDGenerator(nn.Module):
1063
 
1064
 
1065
  # =============================================================================
1066
- # Latent→property helper (uses G2-style LatentPropertyModel joblib artifacts)
1067
  # =============================================================================
1068
  def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
1069
- """
1070
- Mirror G2's predict_latent_property(model, z):
1071
- - PCA transform if present
1072
- - GPR predict (scaled y)
1073
- - inverse-transform via y_scaler if present
1074
- """
1075
  z_use = np.asarray(z, dtype=np.float32)
1076
  if z_use.ndim == 1:
1077
  z_use = z_use.reshape(1, -1)
1078
 
1079
- # Optional PCA
1080
  pca = getattr(latent_model, "pca", None)
1081
  if pca is not None:
1082
  z_use = pca.transform(z_use.astype(np.float32))
1083
 
1084
- # GPR or wrapped predictor
1085
  gpr = getattr(latent_model, "gpr", None)
1086
  if gpr is not None and hasattr(gpr, "predict"):
1087
  y_s = gpr.predict(z_use)
@@ -1092,7 +1087,6 @@ def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarr
1092
 
1093
  y_s = np.array(y_s, dtype=np.float32).reshape(-1)
1094
 
1095
- # Optional scaler to get back to original units
1096
  y_scaler = getattr(latent_model, "y_scaler", None)
1097
  if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
1098
  y_u = y_scaler.inverse_transform(y_s.reshape(-1, 1)).reshape(-1)
@@ -1103,7 +1097,7 @@ def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarr
1103
 
1104
 
1105
  # =============================================================================
1106
- # Legacy models (kept for backward compatibility; not used in new generation path)
1107
  # =============================================================================
1108
  class TransformerDecoderOnly(nn.Module):
1109
  def __init__(
@@ -1197,21 +1191,23 @@ class InverseDesignDecoder(nn.Module):
1197
 
1198
 
1199
  # =============================================================================
1200
- # Config
1201
  # =============================================================================
1202
  class OrchestratorConfig:
1203
- def __init__(self):
 
 
1204
  self.base_dir = "."
1205
- self.cl_weights_path = "/home/kaur-m43/multimodal_output_5M/best/pytorch_model.bin"
1206
- self.chroma_db_path = "chroma_polymer_db_big"
1207
  self.rag_embedding_model = "text-embedding-3-small"
1208
 
1209
  self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
1210
  self.model = os.getenv("OPENAI_MODEL", "gpt-4.1")
1211
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1212
 
1213
- self.spm_model_path = "/home/kaur-m43/spm_5M.model"
1214
- self.spm_vocab_path = "/home/kaur-m43/spm_5M.vocab"
1215
 
1216
  self.springer_api_key = os.getenv("SPRINGER_NATURE_API_KEY", "")
1217
  self.semantic_scholar_key = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
@@ -1223,7 +1219,7 @@ class OrchestratorConfig:
1223
  "property_prediction": True,
1224
  "polymer_generation": True,
1225
  "web_search": True,
1226
- "report_generation": True, # <-- FIX: required by the Gradio interface
1227
  "mol_render": True,
1228
  "gen_grid": True,
1229
  "prop_attribution": True,
@@ -1268,7 +1264,7 @@ TOOL_DESCRIPTIONS = {
1268
  "CrossRef, OpenAlex, EuropePMC, arXiv, Semantic Scholar, Springer Nature (API key), Internet Archive"
1269
  ),
1270
  },
1271
- "report_generation": { # <-- FIX: required by the Gradio interface
1272
  "name": "Report Generation",
1273
  "description": (
1274
  "Synthesizes available tool outputs into a single structured report object "
@@ -1299,6 +1295,10 @@ TOOL_DESCRIPTIONS = {
1299
  class PolymerOrchestrator:
1300
  def __init__(self, config: OrchestratorConfig):
1301
  self.config = config
 
 
 
 
1302
  self._openai_client = None
1303
  self._openai_unavailable_reason = None
1304
  self._data_extractor = None
@@ -1315,6 +1315,9 @@ class PolymerOrchestrator:
1315
 
1316
  self.system_prompt = self._build_system_prompt()
1317
 
 
 
 
1318
  @property
1319
  def openai_client(self):
1320
  if self._openai_client is None:
@@ -1336,7 +1339,7 @@ class PolymerOrchestrator:
1336
  return (
1337
  "You are the tool-planning module for **PolyAgent**, a polymer-science agent.\n"
1338
  "Your job is to inspect the user's questions and decide which tools\n"
1339
- "to run in which order. \n\n"
1340
  "Critical tool dependencies:\n"
1341
  "- property_prediction should run AFTER cl_encoding when possible and should reuse cl_encoding.embedding.\n"
1342
  "- polymer_generation is inverse-design and REQUIRES target_value (property -> PSMILES).\n\n"
@@ -1345,7 +1348,7 @@ class PolymerOrchestrator:
1345
  )
1346
 
1347
  # =============================================================================
1348
- # Planner: LLM tool-calling (no rule-based planner)
1349
  # =============================================================================
1350
  def analyze_query(self, user_query: str) -> Dict[str, Any]:
1351
  schema_keys = ["analysis", "tools_required", "execution_plan"]
@@ -1394,7 +1397,6 @@ class PolymerOrchestrator:
1394
  }
1395
  }
1396
 
1397
- # Preferred: function/tool-calling
1398
  try:
1399
  response = self.openai_client.chat.completions.create(
1400
  model=self.config.model,
@@ -1420,7 +1422,6 @@ class PolymerOrchestrator:
1420
 
1421
  raise RuntimeError("Tool-calling plan not returned; falling back to JSON mode.")
1422
  except Exception:
1423
- # Safe fallback: JSON response_format (still LLM-generated, not rule-based)
1424
  try:
1425
  response = self.openai_client.chat.completions.create(
1426
  model=self.config.model,
@@ -1466,7 +1467,7 @@ class PolymerOrchestrator:
1466
  output = self._run_polymer_generation(step, intermediate_data)
1467
  elif tool_name == "web_search":
1468
  output = self._run_web_search(step, intermediate_data)
1469
- elif tool_name == "report_generation": # <-- FIX
1470
  output = self._run_report_generation(step, intermediate_data)
1471
  elif tool_name == "mol_render":
1472
  output = self._run_mol_render(step, intermediate_data)
@@ -1580,7 +1581,6 @@ class PolymerOrchestrator:
1580
  "year": meta.get("year", ""),
1581
  "source": meta.get("source", meta.get("source_path", "")),
1582
  "venue": meta.get("venue", meta.get("journal", "")),
1583
- # NEW: preserve citable identifiers when present in metadata
1584
  "url": meta.get("url") or meta.get("link") or meta.get("href") or "",
1585
  "doi": meta.get("doi") or "",
1586
  })
@@ -1705,9 +1705,10 @@ class PolymerOrchestrator:
1705
  "attention_mask": torch.ones(1, 2048, dtype=torch.bool, device=self.config.device),
1706
  }
1707
 
1708
- # psmiles tokenization for psmiles encoder
1709
  if self._psmiles_tokenizer is None:
1710
  try:
 
1711
  self._psmiles_tokenizer = build_psmiles_tokenizer()
1712
  except Exception:
1713
  self._psmiles_tokenizer = None
@@ -1734,7 +1735,6 @@ class PolymerOrchestrator:
1734
  with torch.no_grad():
1735
  embeddings_dict = self._cl_encoder.encode(batch_mods)
1736
 
1737
- # enforce that all four modalities are present (gine, schnet, fp, psmiles)
1738
  required_modalities = ("gine", "schnet", "fp", "psmiles")
1739
  missing = [m for m in required_modalities if m not in embeddings_dict]
1740
  if missing:
@@ -1757,8 +1757,8 @@ class PolymerOrchestrator:
1757
  import torch.nn as nn
1758
 
1759
  property_name = canonical_property_name(property_name)
1760
- prop_ckpt = PROPERTY_HEAD_PATHS.get(property_name)
1761
- prop_meta = PROPERTY_HEAD_META.get(property_name)
1762
 
1763
  if prop_ckpt is None:
1764
  raise ValueError(f"No property head registered for: {property_name}")
@@ -1778,7 +1778,6 @@ class PolymerOrchestrator:
1778
 
1779
  ckpt = torch.load(prop_ckpt, map_location=self.config.device, weights_only=False)
1780
 
1781
- # locate state dict
1782
  state_dict = None
1783
  for k in ("state_dict", "model_state_dict", "model_state", "head_state_dict", "regressor_state_dict"):
1784
  if isinstance(ckpt, dict) and k in ckpt and isinstance(ckpt[k], dict):
@@ -1804,7 +1803,6 @@ class PolymerOrchestrator:
1804
 
1805
  head = RegressionHeadOnly(hidden_dim=600, dropout=float(meta.get("dropout", 0.1))).to(self.config.device)
1806
 
1807
- # normalize key prefixes
1808
  normalized = {}
1809
  for k, v in state_dict.items():
1810
  nk = k
@@ -1825,7 +1823,6 @@ class PolymerOrchestrator:
1825
  head.load_state_dict(normalized, strict=False)
1826
  head.eval()
1827
 
1828
- # y scaler
1829
  y_scaler = None
1830
  if isinstance(ckpt, dict):
1831
  for sk in ("y_scaler", "scaler_y", "target_scaler", "y_normalizer"):
@@ -1852,16 +1849,14 @@ class PolymerOrchestrator:
1852
  return {"error": "Specify property name"}
1853
 
1854
  property_name = canonical_property_name(property_name)
1855
- if property_name not in PROPERTY_HEAD_PATHS:
1856
  return {"error": f"Unsupported property: {property_name}"}
1857
 
1858
- # Prefer embedding from cl_encoding output if available
1859
  emb_from_cl = None
1860
  cl = data.get("cl_encoding", None)
1861
  if isinstance(cl, dict) and isinstance(cl.get("embedding"), list) and len(cl["embedding"]) == 600:
1862
  emb_from_cl = torch.tensor([cl["embedding"]], dtype=torch.float32, device=self.config.device)
1863
 
1864
- # If no embedding provided, compute via extraction + CL
1865
  multimodal = data.get("data_extraction", None)
1866
  psmiles = data.get("psmiles", data.get("smiles", None))
1867
  if emb_from_cl is None:
@@ -1879,18 +1874,16 @@ class PolymerOrchestrator:
1879
  with torch.no_grad():
1880
  embs = self._cl_encoder.encode(batch_mods)
1881
 
1882
- # enforce all four modalities
1883
  required_modalities = ("gine", "schnet", "fp", "psmiles")
1884
  missing = [m for m in required_modalities if m not in embs]
1885
  if missing:
1886
  return {"error": f"CL encoder did not return embeddings for modalities: {', '.join(missing)}"}
1887
 
1888
  all_embs = [embs[k] for k in required_modalities]
1889
- emb_from_cl = torch.stack(all_embs, dim=0).mean(dim=0) # (B,600)
1890
  except Exception as e:
1891
  return {"error": f"Failed to compute CL embedding: {e}"}
1892
 
1893
- # Predict
1894
  try:
1895
  head, y_scaler, meta, ckpt_path = self._load_property_head(property_name)
1896
  with torch.no_grad():
@@ -1898,27 +1891,21 @@ class PolymerOrchestrator:
1898
 
1899
  pred_value = float(pred_norm)
1900
 
1901
- # 1) Preferred: inverse_transform using the actual scaler object if available
1902
  if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
1903
  try:
1904
  inv = y_scaler.inverse_transform(np.array([[pred_norm]], dtype=float))
1905
  pred_value = float(inv[0][0])
1906
  except Exception:
1907
  pred_value = float(pred_norm)
1908
-
1909
- # 2) Fallback: use metadata params if scaler object is missing
1910
  else:
1911
  mean = (meta or {}).get("scaler_mean", None)
1912
  scale = (meta or {}).get("scaler_scale", None)
1913
-
1914
- # StandardScaler inverse: x = x_scaled * scale + mean
1915
  try:
1916
  if isinstance(mean, list) and isinstance(scale, list) and len(mean) == 1 and len(scale) == 1:
1917
  pred_value = float(pred_norm) * float(scale[0]) + float(mean[0])
1918
  except Exception:
1919
  pred_value = float(pred_norm)
1920
 
1921
- # best-effort psmiles context
1922
  out_psmiles = None
1923
  if isinstance(multimodal, dict):
1924
  out_psmiles = multimodal.get("canonical_psmiles")
@@ -1933,7 +1920,7 @@ class PolymerOrchestrator:
1933
  "predictions": {property_name: pred_value},
1934
  "prediction_normalized": float(pred_norm),
1935
  "head_checkpoint_path": ckpt_path,
1936
- "metadata_path": PROPERTY_HEAD_META.get(property_name, ""),
1937
  "normalization_applied": bool(
1938
  (y_scaler is not None and hasattr(y_scaler, "inverse_transform")) or
1939
  ((meta or {}).get("scaler_mean") is not None and (meta or {}).get("scaler_scale") is not None)
@@ -1943,11 +1930,8 @@ class PolymerOrchestrator:
1943
  except Exception as e:
1944
  return {"error": f"Property prediction failed: {e}"}
1945
 
1946
- # ----------------- Inverse-design generator (NEW: CL + SELFIES-TED, as in G2) ----------------- #
1947
  def _get_selfies_ted_backend(self, model_name: str) -> Tuple[Any, Any]:
1948
- """
1949
- Cache and return (tokenizer, model) for a given SELFIES-TED model name.
1950
- """
1951
  if not model_name:
1952
  model_name = SELFIES_TED_MODEL_NAME
1953
  if model_name in self._selfies_ted_cache:
@@ -1958,18 +1942,11 @@ class PolymerOrchestrator:
1958
  return tok, model
1959
 
1960
  def _load_property_generator(self, property_name: str):
1961
- """
1962
- Load PolyBART-style inverse-design artifacts produced by G2.py:
1963
- - decoder_best_fold*.pt : state_dict of CLConditionedSelfiesTEDGenerator
1964
- - standardscaler_*.joblib : StandardScaler on property values
1965
- - gpr_psmiles_*.joblib : LatentPropertyModel (z->property)
1966
- - meta.json : meta info (selfies_ted_model, cl_emb_dim, mem_len, tol_scaled, ...)
1967
- """
1968
  property_name = canonical_property_name(property_name)
1969
  if property_name in self._property_generators:
1970
  return self._property_generators[property_name]
1971
 
1972
- base_dir = GENERATOR_DIRS.get(property_name)
1973
  if base_dir is None:
1974
  raise ValueError(f"No generator registered for: {property_name}")
1975
  if not os.path.isdir(base_dir):
@@ -2017,12 +1994,10 @@ class PolymerOrchestrator:
2017
  if not gpr_path or not os.path.exists(gpr_path):
2018
  raise FileNotFoundError(f"GPR *.joblib not found in {base_dir}")
2019
 
2020
- # Latent property model and scaler (G2-style LatentPropertyModel)
2021
  _install_unpickle_shims()
2022
- scaler_y = _safe_joblib_load(scaler_path) # StandardScaler on property
2023
- latent_prop_model = _safe_joblib_load(gpr_path) # should be LatentPropertyModel dataclass-like
2024
 
2025
- # SELFIES-TED backbone
2026
  selfies_ted_name = meta.get("selfies_ted_model", SELFIES_TED_MODEL_NAME)
2027
  tok, selfies_backbone = self._get_selfies_ted_backend(selfies_ted_name)
2028
 
@@ -2037,7 +2012,6 @@ class PolymerOrchestrator:
2037
  ).to(self.config.device)
2038
 
2039
  ckpt = torch.load(decoder_path, map_location=self.config.device, weights_only=False)
2040
- # In G2, decoder_best_fold*.pt is a plain state_dict; keep robust fallback
2041
  state_dict = None
2042
  if isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
2043
  state_dict = ckpt
@@ -2077,17 +2051,10 @@ class PolymerOrchestrator:
2077
  latent_noise_std: float = LATENT_NOISE_STD_GEN,
2078
  extra_factor: int = 8,
2079
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
2080
- """
2081
- Simple PolyBART-style latent sampler:
2082
- - if seed_latents provided, sample Gaussian noise around them and L2-normalize
2083
- - else, sample random latents on unit hypersphere
2084
- - score via latent_prop_model (z->property), keep those near target.
2085
- """
2086
  def _l2_normalize_np(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
2087
  n = np.linalg.norm(x, axis=-1, keepdims=True)
2088
  return x / np.clip(n, eps, None)
2089
 
2090
- # target in scaled space
2091
  if y_scaler is not None and hasattr(y_scaler, "transform"):
2092
  target_s = float(y_scaler.transform(np.array([[target_value]], dtype=np.float32))[0, 0])
2093
  else:
@@ -2127,23 +2094,14 @@ class PolymerOrchestrator:
2127
 
2128
  @torch.no_grad()
2129
  def _run_polymer_generation(self, step: Dict, data: Dict) -> Dict:
2130
- """
2131
- Inverse-design generation (CL latent → pSELFIES via SELFIES-TED → PSMILES).
2132
-
2133
- Corrections implemented:
2134
- 1) ONLY return RDKit-valid generated outputs (filter invalid candidates).
2135
- 2) Replace bracketed [At]/[AT]/[aT]/... with [*] AFTER the RDKit validity check
2136
- but BEFORE writing the response payload.
2137
- """
2138
  property_name = data.get("property", data.get("property_name", None))
2139
  if property_name is None:
2140
  return {"error": "Specify property name for generation"}
2141
 
2142
  property_name = canonical_property_name(property_name)
2143
- if property_name not in GENERATOR_DIRS:
2144
  return {"error": f"Unsupported property: {property_name}"}
2145
 
2146
- # STRICT: require target_value (support a few common aliases)
2147
  if data.get("target_value", None) is not None:
2148
  target_value = data["target_value"]
2149
  elif data.get("target", None) is not None:
@@ -2177,14 +2135,12 @@ class PolymerOrchestrator:
2177
 
2178
  latent_dim = int(getattr(decoder_model, "cl_emb_dim", 600))
2179
 
2180
- # choose target scaler: prefer latent_prop_model.y_scaler, fall back to scaler_y
2181
  y_scaler = getattr(latent_prop_model, "y_scaler", None)
2182
  if y_scaler is None:
2183
  y_scaler = scaler_y if scaler_y is not None else None
2184
 
2185
  tol_scaled = float(tol_scaled_override) if tol_scaled_override is not None else float(meta.get("tol_scaled", 0.5))
2186
 
2187
- # Collect seed latents from available sources:
2188
  seed_latents: List[np.ndarray] = []
2189
  cl_enc = data.get("cl_encoding", None)
2190
  if isinstance(cl_enc, dict) and isinstance(cl_enc.get("embedding"), list):
@@ -2192,7 +2148,6 @@ class PolymerOrchestrator:
2192
  if emb.shape[0] == latent_dim:
2193
  seed_latents.append(emb)
2194
 
2195
- # Optional seed pSMILES strings for biasing
2196
  seeds_str: List[str] = []
2197
  if isinstance(data.get("seed_psmiles_list"), list):
2198
  seeds_str.extend([str(x) for x in data["seed_psmiles_list"] if isinstance(x, str)])
@@ -2200,10 +2155,8 @@ class PolymerOrchestrator:
2200
  seeds_str.append(str(data["seed_psmiles"]))
2201
  if data.get("psmiles") and not seeds_str:
2202
  seeds_str.append(str(data["psmiles"]))
2203
-
2204
  seeds_str = list(dict.fromkeys(seeds_str))
2205
 
2206
- # If seed strings provided but no seed latents yet, compute CL embeddings for each seed
2207
  if seeds_str and not seed_latents:
2208
  self._ensure_cl_encoder()
2209
  for s in seeds_str:
@@ -2216,7 +2169,6 @@ class PolymerOrchestrator:
2216
  if z.shape[0] == latent_dim:
2217
  seed_latents.append(z)
2218
 
2219
- # Sample latents targeting the property
2220
  try:
2221
  Z_keep, y_s_keep, y_u_keep, target_s = self._sample_latents_for_target(
2222
  latent_prop_model=latent_prop_model,
@@ -2232,7 +2184,6 @@ class PolymerOrchestrator:
2232
  except Exception as e:
2233
  return {"error": f"Failed to sample latents conditioned on property: {e}", "paths": paths}
2234
 
2235
- # --- helpers ---
2236
  at_bracket_re = re.compile(r"\[(at)\]", flags=re.IGNORECASE)
2237
 
2238
  def _at_to_star_bracket(s: str) -> str:
@@ -2241,7 +2192,6 @@ class PolymerOrchestrator:
2241
  return at_bracket_re.sub("[*]", s)
2242
 
2243
  def _is_rdkit_valid(psmiles: str) -> bool:
2244
- # If RDKit is unavailable, we cannot validate; treat as "valid" but flag it below.
2245
  if Chem is None:
2246
  return True
2247
  try:
@@ -2251,17 +2201,11 @@ class PolymerOrchestrator:
2251
  except Exception:
2252
  return False
2253
 
2254
- # Decode latents → pSELFIES → PSMILES; filter to RDKit-valid ONLY.
2255
- # Shortening strategy (3rd approach): generate MORE valid candidates, then keep the shortest valid K.
2256
  requested_k = int(num_samples)
2257
-
2258
- # candidates are tuples:
2259
- # (len(psmiles), abs(y_s - target_s), psmiles_out, selfies_str, y_s, y_u)
2260
  candidates: List[Tuple[int, float, str, str, float, float]] = []
2261
 
2262
- # Reuse existing knob (extra_factor) to control "generate more" without adding new API surface.
2263
  candidates_per_latent = max(1, int(extra_factor))
2264
- max_gen_rounds = 4 # best-effort retries to satisfy requested_k under RDKit validity filtering
2265
 
2266
  Z_round, y_s_round, y_u_round = Z_keep, y_s_keep, y_u_keep
2267
  for _round in range(max_gen_rounds):
@@ -2279,9 +2223,7 @@ class PolymerOrchestrator:
2279
  for selfies_str in (outs or []):
2280
  psm_raw = pselfies_to_psmiles(selfies_str)
2281
 
2282
- # Correction #1: validate FIRST on the raw returned string
2283
  if _is_rdkit_valid(psm_raw):
2284
- # Correction #2: convert [At] -> [*] AFTER validation, BEFORE response writing
2285
  psm_out = _at_to_star_bracket(psm_raw)
2286
  candidates.append(
2287
  (
@@ -2296,11 +2238,9 @@ class PolymerOrchestrator:
2296
  except Exception:
2297
  continue
2298
 
2299
- # Stop early once we have enough valid candidates to select the shortest K.
2300
  if len(candidates) >= requested_k:
2301
  break
2302
 
2303
- # If still short, resample latents and try again (best-effort; keeps validity constraints).
2304
  try:
2305
  Z_round, y_s_round, y_u_round, target_s = self._sample_latents_for_target(
2306
  latent_prop_model=latent_prop_model,
@@ -2316,11 +2256,9 @@ class PolymerOrchestrator:
2316
  except Exception:
2317
  break
2318
 
2319
- # Keep shortest valid K (tie-break by closeness to target in scaled space)
2320
  candidates.sort(key=lambda t: (t[0], t[1]))
2321
  selected = candidates[:requested_k]
2322
 
2323
- # Ensure we return as many as requested when possible (repeat shortest valid if needed).
2324
  if selected and len(selected) < requested_k:
2325
  while len(selected) < requested_k:
2326
  selected.append(selected[0])
@@ -2334,8 +2272,8 @@ class PolymerOrchestrator:
2334
  "property": property_name,
2335
  "target_value": float(target_value),
2336
  "num_samples": int(len(generated_psmiles)),
2337
- "generated_psmiles": generated_psmiles, # RDKit-valid ONLY; [At]->[*] applied after validation
2338
- "generated_selfies": selfies_raw, # aligned with generated_psmiles
2339
  "latent_property_predictions": {
2340
  "scaled": decoded_scaled,
2341
  "unscaled": decoded_unscaled,
@@ -2386,11 +2324,11 @@ class PolymerOrchestrator:
2386
  doi = normalize_doi(it.get("DOI", "")) or ""
2387
 
2388
  publisher = (it.get("publisher") or "").lower()
2389
- # Optional: exclude Brill explicitly
2390
  if doi and doi.startswith("10.1163/"):
2391
  continue
2392
  if "brill" in publisher:
2393
  continue
 
2394
  pub_year = None
2395
  if it.get("published-print") and isinstance(it["published-print"].get("date-parts"), list):
2396
  pub_year = it["published-print"]["date-parts"][0][0]
@@ -2402,7 +2340,6 @@ class PolymerOrchestrator:
2402
  doi = ""
2403
  doi_url = ""
2404
 
2405
- # Prefer DOI URL when valid; otherwise fall back to Crossref's URL field if present.
2406
  landing = (it.get("URL") or "") if isinstance(it.get("URL"), str) else ""
2407
  out.append({
2408
  "title": title,
@@ -2433,7 +2370,6 @@ class PolymerOrchestrator:
2433
  continue
2434
 
2435
  doi = normalize_doi(it.get("doi", "")) or ""
2436
- # Optional: exclude Brill explicitly
2437
  if doi and doi.startswith("10.1163/"):
2438
  continue
2439
 
@@ -2450,11 +2386,12 @@ class PolymerOrchestrator:
2450
 
2451
  out.append({
2452
  "title": it.get("title", ""),
2453
- "doi": doi, # normalized, not a URL
2454
- "url": landing or "", # prefer landing page URL
2455
  "year": it.get("publication_year") or (it.get("publication_date", "")[:4]),
2456
  "venue": (it.get("host_venue") or {}).get("display_name", ""),
2457
- "type": oa_type, "source": "OpenAlex",
 
2458
  })
2459
  return out
2460
  except Exception as e:
@@ -2648,25 +2585,16 @@ class PolymerOrchestrator:
2648
  return {"error": f"Unsupported web_search source: {src}"}
2649
 
2650
  # =============================================================================
2651
- # REPORT GENERATION (FIX for Gradio interface expectations)
2652
  # =============================================================================
2653
  def generate_report(self, data: Dict[str, Any]) -> Dict[str, Any]:
2654
- """
2655
- Minimal, interface-safe report generator used by the Gradio UI fallback.
2656
-
2657
- - Runs data_extraction -> cl_encoding -> property_prediction when possible
2658
- - Optionally runs polymer_generation if generate=True or if target_value present
2659
- - Optionally runs web_search if a query/literature_query is present
2660
- """
2661
  payload = dict(data or {})
2662
  summary: Dict[str, Any] = {}
2663
 
2664
- # Seed psmiles/property
2665
  prop = payload.get("property") or payload.get("property_name")
2666
  if prop:
2667
  payload["property"] = prop
2668
 
2669
- # NEW: infer property from questions when missing
2670
  if not payload.get("property"):
2671
  qtxt = payload.get("questions") or payload.get("question") or ""
2672
  inferred_prop = infer_property_from_text(qtxt)
@@ -2677,39 +2605,33 @@ class PolymerOrchestrator:
2677
  if psmiles:
2678
  payload["psmiles"] = psmiles
2679
 
2680
- # NEW: infer target_value from questions when missing (only useful for generation)
2681
  if payload.get("target_value", None) is None:
2682
  qtxt = payload.get("questions") or payload.get("question") or ""
2683
  inferred_tgt = infer_target_value_from_text(qtxt, payload.get("property"))
2684
  if inferred_tgt is not None:
2685
  payload["target_value"] = float(inferred_tgt)
2686
 
2687
- # 1) data_extraction
2688
  if psmiles and "data_extraction" not in payload:
2689
  ex = self._run_data_extraction({"step": -1}, payload)
2690
  payload["data_extraction"] = ex
2691
  summary["data_extraction"] = ex
2692
 
2693
- # 2) cl_encoding
2694
  if "data_extraction" in payload and "cl_encoding" not in payload:
2695
  cl = self._run_cl_encoding({"step": -1}, payload)
2696
  payload["cl_encoding"] = cl
2697
  summary["cl_encoding"] = cl
2698
 
2699
- # 3) property_prediction
2700
  if payload.get("property") and "property_prediction" not in payload:
2701
  pp = self._run_property_prediction({"step": -1}, payload)
2702
  payload["property_prediction"] = pp
2703
  summary["property_prediction"] = pp
2704
 
2705
- # 4) polymer_generation (optional)
2706
  do_gen = bool(payload.get("generate", False)) or (payload.get("target_value", None) is not None)
2707
  if do_gen and payload.get("property") and payload.get("target_value", None) is not None:
2708
  gen = self._run_polymer_generation({"step": -1}, payload)
2709
  payload["polymer_generation"] = gen
2710
  summary["generation"] = gen
2711
 
2712
- # 5) web_search (optional)
2713
  q = payload.get("query") or payload.get("literature_query")
2714
  src = payload.get("source") or "all"
2715
  if q:
@@ -2730,7 +2652,6 @@ class PolymerOrchestrator:
2730
  "questions": payload.get("questions") or payload.get("question") or "",
2731
  }
2732
 
2733
- # Add domain tags + (domain.com) cite tags, and tool tags [T#]
2734
  report = _attach_source_domains(report)
2735
  report = _index_citable_sources(report)
2736
  report = _assign_tool_tags_to_report(report)
@@ -2740,32 +2661,23 @@ class PolymerOrchestrator:
2740
  def _run_report_generation(self, step: Dict, data: Dict) -> Dict[str, Any]:
2741
  return self.generate_report(data)
2742
 
 
 
 
2743
  def compose_gpt_style_answer(
2744
  self,
2745
  report: Dict[str, Any],
2746
  case_brief: str = "",
2747
  questions: str = "",
2748
  ) -> Tuple[str, List[str]]:
2749
- """
2750
- Interface-safe composer. Uses OpenAI if available; otherwise returns a deterministic markdown.
2751
- Must return: (final_markdown, list_of_image_paths).
2752
-
2753
- Updated requirements:
2754
- - No fixed answer template: structure must follow the user's actual questions.
2755
- - Literature/web citations must be domain-style like "nature.com" (never [1], [2], ...). No parentheses.
2756
- - Tool-derived facts must cite as [T] only.
2757
- - Tool outputs should be available verbatim without tweaking (appended as JSON blocks).
2758
- """
2759
  imgs: List[str] = []
2760
 
2761
- # Ensure tags exist even if caller didn't run generate_report()
2762
  if isinstance(report, dict):
2763
  report = _attach_source_domains(report)
2764
  report = _index_citable_sources(report)
2765
  report = _assign_tool_tags_to_report(report)
2766
 
2767
  if self.openai_client is None:
2768
- # Deterministic fallback (no API dependency)
2769
  md_lines = []
2770
  if case_brief:
2771
  md_lines.append(case_brief.strip())
@@ -2780,7 +2692,6 @@ class PolymerOrchestrator:
2780
  md_lines.append(str(report))
2781
  md_lines.append("```")
2782
 
2783
- # Verbatim tool outputs (no tweaking)
2784
  verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
2785
  if verb:
2786
  md_lines.append("\n---\n\n## Tool outputs (verbatim)\n")
@@ -2788,7 +2699,6 @@ class PolymerOrchestrator:
2788
 
2789
  return "\n".join(md_lines), imgs
2790
 
2791
- # OpenAI-based synthesis
2792
  try:
2793
  prompt = (
2794
  "You are PolyAgent - consider yourself as an expert in polymer science. Answer the user's questions using ONLY the provided report.\n"
@@ -2804,12 +2714,11 @@ class PolymerOrchestrator:
2804
  "- NON-DUPLICATES: Do not repeat the same paper link. Each DOI/URL may appear at most once in the entire answer.\n"
2805
  "- Each major section should include at least 1 inline literature citation when relevant.\n"
2806
  "- Do NOT invent DOIs, URLs, titles, or sources.\n\n"
2807
- "- CITATIONS AS SPECIFIED ONLY: very strictly place each citation immediately after the claim it supports; do not add a references list.\n"
2808
  "OUTPUT RULES (STRICT):\n"
2809
  "- If a numeric value is not present in the report, write 'not available'.\n"
2810
  "- Preserve polymer endpoint tokens exactly as '[*]' in any pSMILES/SMILES shown.\n"
2811
  "- To prevent markdown mangling, put any pSMILES/SMILES inside code formatting.\n"
2812
- "- Do not rewrite or tweak any tool outputs; if you refer to them, reference them by tag (e.g., [T2]).\n\n"
2813
  f"CASE BRIEF:\n{case_brief}\n\n"
2814
  f"QUESTIONS:\n{questions}\n\n"
2815
  f"REPORT (JSON):\n{json.dumps(report, ensure_ascii=False)}\n"
@@ -2825,16 +2734,12 @@ class PolymerOrchestrator:
2825
  )
2826
  txt = resp.choices[0].message.content or ""
2827
 
2828
- # Enforce distributed inline clickable paper citations (do not touch tool citations).
2829
- # This corrects cases where the model under-cites or clusters citations.
2830
  try:
2831
  min_cites = _infer_required_citation_count(questions or "", default_n=10)
2832
  txt = _ensure_distributed_inline_citations(txt, report, min_needed=min_cites)
2833
  except Exception:
2834
  pass
2835
 
2836
-
2837
- # Enforce: DOI-URL bracket text + dedupe (each DOI/URL appears at most once)
2838
  try:
2839
  txt = _normalize_and_dedupe_literature_links(txt, report)
2840
  except Exception:
@@ -2845,23 +2750,20 @@ class PolymerOrchestrator:
2845
  except Exception:
2846
  pass
2847
 
2848
- # Always append verbatim tool outputs (no tweaking)
2849
  verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
2850
  if verb:
2851
  txt = txt.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
2852
 
2853
  return txt, imgs
2854
  except Exception as e:
2855
- # Last-resort fallback
2856
  md = f"OpenAI compose failed: {e}\n\n```json\n{json.dumps(report, indent=2, ensure_ascii=False)}\n```"
2857
- # Still append verbatim tool outputs
2858
  verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
2859
  if verb:
2860
  md = md.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
2861
  return md, imgs
2862
 
2863
  # =============================================================================
2864
- # VISUAL TOOLS (PNG-only)
2865
  # =============================================================================
2866
  def _run_mol_render(self, step: Dict, data: Dict) -> Dict[str, Any]:
2867
  out_dir = Path("viz")
@@ -2915,16 +2817,6 @@ class PolymerOrchestrator:
2915
  return {"png_path": png, "n": len(mols)}
2916
 
2917
  def _run_prop_attribution(self, step: Dict, data: Dict) -> Dict[str, Any]:
2918
- """
2919
- FIXED explainability:
2920
- - Leave-one-atom-out occlusion attribution:
2921
- score_i = baseline_pred - pred(mask atom i -> wildcard)
2922
- - Highlight ONLY meaningful atoms:
2923
- * Rank by |score|
2924
- * Apply relative threshold vs max |score| (default 0.25)
2925
- * Cap by top-K
2926
- * Ensure at least 1 atom highlighted
2927
- """
2928
  out_dir = Path("viz")
2929
  out_dir.mkdir(parents=True, exist_ok=True)
2930
 
@@ -2935,11 +2827,10 @@ class PolymerOrchestrator:
2935
  prop = canonical_property_name(data.get("property") or data.get("property_name") or "glass transition")
2936
  top_k = int(data.get("top_k_atoms", data.get("top_k", 12)))
2937
 
2938
- # importance threshold controls
2939
  min_rel_importance = float(data.get("min_rel_importance", 0.25))
2940
  min_abs_importance = float(data.get("min_abs_importance", 0.0))
2941
 
2942
- if prop not in PROPERTY_HEAD_PATHS:
2943
  return {"error": f"Unsupported property for attribution: {prop}"}
2944
  if not p:
2945
  return {"error": "no psmiles"}
@@ -2960,7 +2851,6 @@ class PolymerOrchestrator:
2960
  if not isinstance(baseline, (float, int)):
2961
  return {"error": "Baseline prediction not numeric"}
2962
 
2963
- # Occlusion loop (O(N_atoms) property predictions)
2964
  scores: Dict[int, float] = {}
2965
  for idx in range(num_atoms):
2966
  try:
@@ -2968,7 +2858,7 @@ class PolymerOrchestrator:
2968
  tmp.GetAtomWithIdx(idx).SetAtomicNum(0) # wildcard
2969
  mutated = tmp.GetMol()
2970
  mut_smiles = Chem.MolToSmiles(mutated)
2971
- mut_psmiles = normalize_generated_psmiles_out(mut_smiles) # [*] -> *
2972
  except Exception:
2973
  scores[idx] = 0.0
2974
  continue
@@ -2980,7 +2870,6 @@ class PolymerOrchestrator:
2980
  else:
2981
  scores[idx] = float(baseline) - float(mut_val)
2982
 
2983
- # Select atoms: top-K by |score| but also require significance
2984
  max_abs = max((abs(v) for v in scores.values()), default=0.0)
2985
  rel_thresh = (min_rel_importance * max_abs) if max_abs > 0 else 0.0
2986
  thresh = max(float(min_abs_importance), float(rel_thresh))
@@ -2991,11 +2880,9 @@ class PolymerOrchestrator:
2991
  selected = [i for i, v in ranked if abs(v) >= thresh]
2992
  selected = selected[:k_cap]
2993
 
2994
- # Ensure at least one highlighted atom
2995
  if not selected and ranked:
2996
  selected = [ranked[0][0]]
2997
 
2998
- # Map colors (coolwarm) over selected only
2999
  atom_colors: Dict[int, tuple] = {}
3000
  sel_scores = np.array([scores[i] for i in selected], dtype=float)
3001
  if cm is not None and sel_scores.size > 0:
@@ -3043,7 +2930,6 @@ class PolymerOrchestrator:
3043
  except Exception as e:
3044
  return {"error": f"prop_attribution rendering failed: {e}"}
3045
 
3046
- # convenience
3047
  def process_query(self, user_query: str, user_inputs: Dict[str, Any] = None) -> Dict[str, Any]:
3048
  plan = self.analyze_query(user_query)
3049
  results = self.execute_plan(plan, user_inputs)
@@ -3051,6 +2937,6 @@ class PolymerOrchestrator:
3051
 
3052
 
3053
  if __name__ == "__main__":
3054
- cfg = OrchestratorConfig()
3055
  orch = PolymerOrchestrator(cfg)
3056
- print("PolymerOrchestrator ready (5M heads + 5M inverse-design + LLM tool-calling planner + occlusion explainability).")
 
1
+ """
2
+ PolyAgent Orchestrator (5M)
3
+ ===========================
4
+
5
+ This file provides a modular orchestrator that:
6
+ - extracts polymer multimodal data (graph/geometry/fingerprints/PSMILES)
7
+ - encodes CL embeddings using PolyFusion encoders
8
+ - predicts single properties using best downstream heads
9
+ - performs inverse design using a CL-conditioned SELFIES-TED generator
10
+ - retrieves literature via local RAG + web APIs
11
+ - visualizes polymer renderings and explainability maps
12
+ - composes a final response along with verbatim tool outputs
13
+ """
14
+
15
  import os
16
  import re
17
  import json
 
20
  from pathlib import Path
21
  from typing import Dict, Any, List, Optional, Tuple
22
  from urllib.parse import urlparse
23
+
24
  import numpy as np
25
  import torch
26
  import torch.nn as nn
 
61
  except Exception:
62
  spm = None
63
 
64
+ # Optional: selfies (for SELFIES→SMILES/PSMILES conversion)
65
  try:
66
  import selfies as sf
67
  except Exception:
 
71
  SELFIES_AVAILABLE = sf is not None
72
 
73
 
74
+ # =============================================================================
75
+ # PATHS / CONFIGURATION
76
+ # =============================================================================
77
+ class PathsConfig:
78
+ """
79
+ Centralized path placeholders. Replace these with your local paths.
80
+ """
81
+ # CL weights
82
+ cl_weights_path = "/path/to/multimodal_output_5M/best/pytorch_model.bin"
83
+
84
+ # Chroma DB (local RAG vectorstore persist dir)
85
+ chroma_db_path = "/path/to/chroma_polymer_db_big"
86
+
87
+ # SentencePiece model + vocab
88
+ spm_model_path = "/path/to/spm_5M.model"
89
+ spm_vocab_path = "/path/to/spm_5M.vocab"
90
+
91
+ # Downstream bestweights directory produced by your 5M downstream script
92
+ downstream_bestweights_5m_dir = "/path/to/multimodal_downstream_bestweights_5M"
93
+
94
+ # Inverse-design generator artifacts directory produced by your 5M inverse design run
95
+ inverse_design_5m_dir = "/path/to/multimodal_inverse_design_output_5M_polybart_style/best_models"
96
+
97
+
98
  # =============================================================================
99
  # DOI NORMALIZATION / RESOLUTION HELPERS
100
  # =============================================================================
101
  _DOI_RE = re.compile(r"^10\.\d{4,9}/\S+$", re.IGNORECASE)
102
 
103
+
104
  def normalize_doi(raw: str) -> Optional[str]:
105
  if not isinstance(raw, str):
106
  return None
 
114
  s = s.rstrip(").,;]}")
115
  return s if _DOI_RE.match(s) else None
116
 
117
+
118
  def doi_to_url(doi: str) -> str:
119
  # doi is assumed normalized
120
  return f"https://doi.org/{doi}"
121
 
122
+
123
  def doi_resolves(doi_url: str, timeout: float = 6.0) -> bool:
124
  """
125
  Best-effort resolver check. Keeps pipeline robust against dead/unregistered DOIs.
 
136
  except Exception:
137
  return False
138
 
139
+
140
  # =============================================================================
141
+ # CITATION / DOMAIN TAGGING HELPERS
142
  # =============================================================================
143
  def _url_to_domain(url: str) -> Optional[str]:
144
  if not isinstance(url, str) or not url.strip():
 
175
  except Exception:
176
  return None
177
 
178
+
179
  def _attach_source_domains(obj: Any) -> Any:
180
  """
181
  Recursively add a short source_domain field where URLs are present.
182
+ This enables domain-style citations like "(nature.com)" (note: the composer
183
+ later enforces DOI-URL bracket citations for papers).
184
  """
185
  if isinstance(obj, list):
186
  return [_attach_source_domains(x) for x in obj]
 
204
  def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
205
  """
206
  Add 'cite_tag' fields for citable web/RAG items using DOI-first URL tags.
207
+
208
  Requirement:
209
  - Paper citations must use the COMPLETE DOI URL (https://doi.org/...) as the bracket text.
210
  - If DOI is not available, fall back to the best http(s) URL.
 
226
  return False
227
 
228
  def get_best_url(d: Dict[str, Any]) -> Optional[str]:
229
+ # DOI-first
230
  doi = normalize_doi(d.get("doi", ""))
231
  if doi:
232
  return doi_to_url(doi)
 
244
  out = {k: walk_and_tag(v) for k, v in node.items()}
245
 
246
  if is_citable_item(out):
 
247
  url = get_best_url(out)
248
  if isinstance(url, str) and url.startswith(("http://", "https://")):
 
249
  cur = out.get("cite_tag")
250
  if not (isinstance(cur, str) and cur.strip().startswith(("http://", "https://"))):
251
  out["cite_tag"] = url.strip()
 
 
 
252
 
 
253
  url = get_best_url(out)
254
  dom = out.get("source_domain") or (_url_to_domain(url) if url else None) or "source"
255
  citation_index["sources"].append(
256
  {
 
257
  "tag": out.get("cite_tag") if isinstance(out.get("cite_tag"), str) else url,
258
  "domain": dom,
259
  "title": out.get("title") or out.get("name") or "Untitled",
 
274
  return report
275
 
276
 
277
+ # =============================================================================
278
+ # INLINE CITATION ENFORCERS (distributed, deduped, DOI-first)
279
+ # =============================================================================
 
280
  _CITE_COUNT_PATTERNS = [
281
  r"(?:at\s+least\s+)?(\d{1,3})\s*(?:citations|citation|papers|paper|sources|source|references|reference)\b",
282
  r"\bcite\s+(\d{1,3})\s*(?:papers|paper|sources|source|references|reference|citations|citation)\b",
 
299
 
300
  def _collect_citation_links_from_report(report: Dict[str, Any]) -> List[Tuple[str, str]]:
301
  """
302
+ Return unique (cite_text, url) pairs from report['citation_index']['sources'].
303
+ cite_text is strictly the DOI URL (preferred) or URL fallback.
304
  """
305
  out: List[Tuple[str, str]] = []
306
  seen: set = set()
 
317
  url = s.get("url")
318
  if not isinstance(url, str) or not url.startswith(("http://", "https://")):
319
  continue
 
320
  cite_text = s.get("tag") if isinstance(s.get("tag"), str) and s.get("tag").strip() else url
321
  if not isinstance(cite_text, str) or not cite_text.strip():
322
  cite_text = url
 
346
  if not citations:
347
  return md
348
 
 
349
  existing_urls = set(re.findall(r"\[[^\]]+\]\((https?://[^)]+)\)", md))
350
  need = max(0, int(min_needed) - len(existing_urls))
351
  if need <= 0:
352
  return md
353
 
 
354
  remaining: List[Tuple[str, str]] = [(d, u) for (d, u) in citations if u not in existing_urls]
355
  if not remaining:
356
  return md
357
 
 
358
  parts = re.split(r"(```[\s\S]*?```)", md)
359
  rem_i = 0
360
 
 
364
  if part.startswith("```") and part.endswith("```"):
365
  continue
366
 
 
367
  segs = re.split(r"(\n\s*\n)", part)
368
  for si in range(0, len(segs), 2):
369
  if rem_i >= len(remaining) or need <= 0:
 
371
  para = segs[si]
372
  if not isinstance(para, str) or not para.strip():
373
  continue
 
374
  if para.lstrip().startswith("#"):
375
  continue
 
376
  if re.search(r"\[[^\]]+\]\((https?://[^)]+)\)", para):
377
  continue
378
+ if not re.search(
379
+ r"\b(reported|shown|demonstrated|study|studies|literature|evidence|review|according)\b",
380
+ para,
381
+ flags=re.IGNORECASE,
382
+ ):
383
  continue
384
 
385
  cite_text, url = remaining[rem_i]
 
386
  segs[si] = para.rstrip() + f" [{cite_text}]({url})"
387
  rem_i += 1
388
  need -= 1
389
 
390
  parts[pi] = "".join(segs)
391
 
 
392
  if need > 0 and rem_i < len(remaining):
393
  md2 = "".join(parts)
394
  parts2 = re.split(r"(```[\s\S]*?```)", md2)
 
420
 
421
  def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> str:
422
  """
423
+ Enforce:
424
+ - Link text must be COMPLETE DOI URL (preferred) or URL fallback.
425
+ - Each DOI/URL appears at most once in the answer.
426
  Only operates outside fenced code blocks.
427
  """
428
  if not isinstance(md, str) or not md.strip():
 
430
  if not isinstance(report, dict):
431
  return md
432
 
 
433
  url_to_text: Dict[str, str] = {}
434
  ci = report.get("citation_index", {})
435
  sources = ci.get("sources") if isinstance(ci, dict) else None
 
449
 
450
  def _rewrite_and_dedupe(text: str) -> str:
451
  def repl(m: re.Match) -> str:
 
452
  url = m.group(2).strip()
453
  if url in seen_urls:
 
454
  return ""
455
  seen_urls.add(url)
456
  pref = url_to_text.get(url, url)
457
  return f"[{pref}]({url})"
458
+
459
  return re.sub(r"\[([^\]]+)\]\((https?://[^)]+)\)", repl, text)
460
 
461
  for i, part in enumerate(parts):
462
  if part.startswith("```") and part.endswith("```"):
463
  continue
464
  parts[i] = _rewrite_and_dedupe(part)
 
465
  parts[i] = re.sub(r"[ \t]{2,}", " ", parts[i])
466
  parts[i] = re.sub(r"\n{3,}", "\n\n", parts[i])
467
 
 
471
  def autolink_doi_urls(md: str) -> str:
472
  """
473
  Wrap bare DOI URLs in Markdown links outside code blocks.
 
474
  """
475
  if not md:
476
  return md
 
481
  parts[i] = re.sub(
482
  r"(?<!\]\()(?P<u>https?://doi\.org/10\.\d{4,9}/[^\s\)\],;]+)",
483
  lambda m: f"[{m.group('u')}]({m.group('u')})",
484
+ part,
485
  flags=re.IGNORECASE,
486
  )
487
  return "".join(parts)
488
 
489
+
490
+ # =============================================================================
491
+ # TOOL TAGS + VERBATIM TOOL OUTPUT RENDERER
492
+ # =============================================================================
493
  def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
494
  """
495
+ Ensure each tool output has a [T] cite tag.
 
496
  """
497
  if not isinstance(report, dict):
498
  return report
 
501
  if not isinstance(tool_outputs, dict):
502
  return report
503
 
504
+ preferred = [
 
505
  "data_extraction",
506
  "cl_encoding",
507
  "property_prediction",
 
511
  "report_generation",
512
  ]
513
 
 
514
  tool_tag_map: Dict[str, str] = {}
515
  tag = "[T]"
516
 
517
+ for tool in preferred:
 
518
  node = tool_outputs.get(tool)
519
  if node is None:
520
  continue
 
522
  if isinstance(node, dict) and not node.get("cite_tag"):
523
  node["cite_tag"] = tag
524
 
 
525
  for tool, node in tool_outputs.items():
526
  if tool in tool_tag_map or node is None:
527
  continue
 
529
  if isinstance(node, dict) and not node.get("cite_tag"):
530
  node["cite_tag"] = tag
531
 
 
532
  try:
533
  summary = report.get("summary", {}) or {}
534
  if isinstance(summary, dict):
 
535
  key_to_tool = {
536
  "data_extraction": "data_extraction",
537
  "cl_encoding": "cl_encoding",
 
555
 
556
  def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
557
  """
558
+ Render tool outputs as verbatim JSON blocks (no content rewriting).
559
  """
560
  if not isinstance(report, dict):
561
  return ""
 
564
  if not isinstance(tool_outputs, dict):
565
  return ""
566
 
 
567
  preferred = [
568
  "data_extraction",
569
  "cl_encoding",
 
591
 
592
 
593
  # =============================================================================
594
+ # PICKLE / JOBLIB COMPATIBILITY SHIMS
595
  # =============================================================================
596
  class LatentPropertyModel:
597
  """
598
  Compatibility shim for joblib/pickle artifacts saved with references like:
599
  __main__.LatentPropertyModel
 
 
 
 
 
 
 
600
  """
601
  def predict(self, X):
 
602
  for attr in ("model", "gpr", "gpr_model", "estimator", "predictor", "_model", "_gpr"):
603
  if hasattr(self, attr):
604
  obj = getattr(self, attr)
605
  if hasattr(obj, "predict"):
606
  return obj.predict(X)
 
607
  raise AttributeError(
608
  "LatentPropertyModel shim could not find an underlying predictor. "
609
  "Artifact expects a wrapped model attribute with a .predict method."
 
613
  def _install_unpickle_shims() -> None:
614
  """
615
  Ensure that any classes pickled under __main__ are available at load time.
 
616
  """
617
  main_mod = sys.modules.get("__main__")
618
  if main_mod is not None and not hasattr(main_mod, "LatentPropertyModel"):
 
630
  return joblib.load(path)
631
  except Exception as e:
632
  msg = str(e)
 
633
  if "Can't get attribute 'LatentPropertyModel' on <module '__main__'" in msg:
634
  _install_unpickle_shims()
635
  return joblib.load(path)
 
637
 
638
 
639
  # =============================================================================
640
+ # PROPERTY + GENERATOR REGISTRY
641
  # =============================================================================
642
+ def build_property_registries(paths: PathsConfig):
643
+ """
644
+ Build registry dicts for:
645
+ - downstream property heads (checkpoint + metadata)
646
+ - inverse-design generator directories
647
+ """
648
+ downstream = paths.downstream_bestweights_5m_dir
649
+ invgen = paths.inverse_design_5m_dir
650
+
651
+ PROPERTY_HEAD_PATHS = {
652
+ "density": os.path.join(downstream, "density", "best_run_checkpoint.pt"),
653
+ "glass transition": os.path.join(downstream, "glass_transition", "best_run_checkpoint.pt"),
654
+ "melting": os.path.join(downstream, "melting", "best_run_checkpoint.pt"),
655
+ "specific volume": os.path.join(downstream, "specific_volume", "best_run_checkpoint.pt"),
656
+ "thermal decomposition": os.path.join(downstream, "thermal_decomposition", "best_run_checkpoint.pt"),
657
+ }
658
+
659
+ PROPERTY_HEAD_META = {
660
+ "density": os.path.join(downstream, "density", "best_run_metadata.json"),
661
+ "glass transition": os.path.join(downstream, "glass_transition", "best_run_metadata.json"),
662
+ "melting": os.path.join(downstream, "melting", "best_run_metadata.json"),
663
+ "specific volume": os.path.join(downstream, "specific_volume", "best_run_metadata.json"),
664
+ "thermal decomposition": os.path.join(downstream, "thermal_decomposition", "best_run_metadata.json"),
665
+ }
666
+
667
+ GENERATOR_DIRS = {
668
+ "density": os.path.join(invgen, "density"),
669
+ "glass transition": os.path.join(invgen, "glass_transition"),
670
+ "melting": os.path.join(invgen, "melting"),
671
+ "specific volume": os.path.join(invgen, "specific_volume"),
672
+ "thermal decomposition": os.path.join(invgen, "thermal_decomposition"),
673
+ }
674
+
675
+ return PROPERTY_HEAD_PATHS, PROPERTY_HEAD_META, GENERATOR_DIRS
676
 
677
 
678
  # =============================================================================
679
+ # Property name canonicalization + inference helpers
680
  # =============================================================================
681
  def canonical_property_name(name: str) -> str:
682
  """
683
+ Map user/tool inputs to the canonical keys used in registries.
684
  """
685
  if not isinstance(name, str):
686
  return ""
 
704
  return aliases.get(s, s)
705
 
706
 
 
 
 
 
707
  _NUM_RE = r"[-+]?\d+(?:\.\d+)?"
708
 
709
+
710
  def infer_property_from_text(text: str) -> Optional[str]:
711
  s = (text or "").lower()
 
712
  m = re.search(r"\bproperty\b\s*[:=]\s*([a-zA-Z _-]+)", s)
713
  if m:
714
  cand = m.group(1).strip().lower()
 
735
  return "density"
736
  return None
737
 
738
+
739
  def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[float]:
740
  sl = (text or "").lower()
741
 
 
767
  except Exception:
768
  pass
769
 
 
770
  tokens = []
771
  if prop == "glass transition":
772
  tokens = ["tg", "glass transition"]
 
791
 
792
  return None
793
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
794
 
795
  # =============================================================================
796
+ # Tokenizers
797
  # =============================================================================
798
  class SimpleCharTokenizer:
799
  def __init__(self, vocab_chars: List[str], special_tokens=("<pad>", "<s>", "</s>", "<unk>")):
 
853
  blocked.append(tid)
854
  setattr(self, "_blocked_ids", blocked)
855
 
 
856
  if self.PieceToId("*") is None:
857
  raise RuntimeError("SentencePiece tokenizer loaded but '*' token not found – aborting for safe PSMILES generation.")
858
 
 
890
  s = re.sub(r"\*", "[*]", s)
891
  return s
892
 
893
+
894
  _AT_BRACKET_UI_RE = re.compile(r"\[(at)\]", flags=re.IGNORECASE)
895
 
896
+
897
  def replace_at_with_star(psmiles: str) -> str:
898
  if not isinstance(psmiles, str) or not psmiles:
899
  return psmiles
900
  return _AT_BRACKET_UI_RE.sub("[*]", psmiles)
901
 
902
+
903
  # =============================================================================
904
+ # SELFIES utilities
905
  # =============================================================================
906
  _SELFIES_TOKEN_RE = re.compile(r"\[[^\[\]]+\]")
907
 
 
947
  def pselfies_to_psmiles(selfies_str: str) -> str:
948
  """
949
  For this orchestrator we treat pSELFIES→PSMILES as SELFIES→canonical SMILES.
 
 
 
950
  """
951
  return selfies_to_smiles(selfies_str)
952
 
953
 
954
  # =============================================================================
955
+ # SELFIES-TED decoder
956
  # =============================================================================
957
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
958
  SELFIES_TED_MODEL_NAME = os.environ.get("SELFIES_TED_MODEL_NAME", "ibm-research/materials.selfies-ted")
959
 
 
960
  GEN_MAX_LEN = 256
961
  GEN_MIN_LEN = 10
962
  GEN_TOP_P = 0.92
963
  GEN_TEMPERATURE = 1.0
964
  GEN_REPETITION_PENALTY = 1.05
965
+ LATENT_NOISE_STD_GEN = 0.15
966
 
967
 
968
  def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
 
 
 
969
  import time
970
  last_err = None
971
  for t in range(max_tries):
 
981
 
982
  def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
983
  """
984
+ Load tokenizer + seq2seq model for SELFIES-TED.
985
  """
986
  def _load_tok():
987
  return AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, use_fast=True)
 
996
 
997
  class CLConditionedSelfiesTEDGenerator(nn.Module):
998
  """
999
+ CL embedding (latent) -> fixed-length memory -> conditions SELFIES-TED seq2seq.
 
1000
  """
1001
  def __init__(self, tok, seq2seq_model, cl_emb_dim: int = 600, mem_len: int = 4):
1002
  super().__init__()
 
1044
  temperature: float = GEN_TEMPERATURE,
1045
  repetition_penalty: float = GEN_REPETITION_PENALTY,
1046
  ) -> List[str]:
 
 
 
1047
  self.eval()
1048
  z = z.to(next(self.parameters()).device)
1049
  enc_out, attn = self.build_encoder_outputs(z)
 
1066
 
1067
 
1068
  # =============================================================================
1069
+ # Latent -> property helper
1070
  # =============================================================================
1071
  def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
 
 
 
 
 
 
1072
  z_use = np.asarray(z, dtype=np.float32)
1073
  if z_use.ndim == 1:
1074
  z_use = z_use.reshape(1, -1)
1075
 
 
1076
  pca = getattr(latent_model, "pca", None)
1077
  if pca is not None:
1078
  z_use = pca.transform(z_use.astype(np.float32))
1079
 
 
1080
  gpr = getattr(latent_model, "gpr", None)
1081
  if gpr is not None and hasattr(gpr, "predict"):
1082
  y_s = gpr.predict(z_use)
 
1087
 
1088
  y_s = np.array(y_s, dtype=np.float32).reshape(-1)
1089
 
 
1090
  y_scaler = getattr(latent_model, "y_scaler", None)
1091
  if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
1092
  y_u = y_scaler.inverse_transform(y_s.reshape(-1, 1)).reshape(-1)
 
1097
 
1098
 
1099
  # =============================================================================
1100
+ # Legacy models
1101
  # =============================================================================
1102
  class TransformerDecoderOnly(nn.Module):
1103
  def __init__(
 
1191
 
1192
 
1193
  # =============================================================================
1194
+ # Orchestrator config
1195
  # =============================================================================
1196
  class OrchestratorConfig:
1197
+ def __init__(self, paths: Optional[PathsConfig] = None):
1198
+ self.paths = paths or PathsConfig()
1199
+
1200
  self.base_dir = "."
1201
+ self.cl_weights_path = self.paths.cl_weights_path
1202
+ self.chroma_db_path = self.paths.chroma_db_path
1203
  self.rag_embedding_model = "text-embedding-3-small"
1204
 
1205
  self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
1206
  self.model = os.getenv("OPENAI_MODEL", "gpt-4.1")
1207
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1208
 
1209
+ self.spm_model_path = self.paths.spm_model_path
1210
+ self.spm_vocab_path = self.paths.spm_vocab_path
1211
 
1212
  self.springer_api_key = os.getenv("SPRINGER_NATURE_API_KEY", "")
1213
  self.semantic_scholar_key = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
 
1219
  "property_prediction": True,
1220
  "polymer_generation": True,
1221
  "web_search": True,
1222
+ "report_generation": True, # required by UI
1223
  "mol_render": True,
1224
  "gen_grid": True,
1225
  "prop_attribution": True,
 
1264
  "CrossRef, OpenAlex, EuropePMC, arXiv, Semantic Scholar, Springer Nature (API key), Internet Archive"
1265
  ),
1266
  },
1267
+ "report_generation": {
1268
  "name": "Report Generation",
1269
  "description": (
1270
  "Synthesizes available tool outputs into a single structured report object "
 
1295
  class PolymerOrchestrator:
1296
  def __init__(self, config: OrchestratorConfig):
1297
  self.config = config
1298
+
1299
+ # Build registries from placeholders (no behavior change; just centralization)
1300
+ self.PROPERTY_HEAD_PATHS, self.PROPERTY_HEAD_META, self.GENERATOR_DIRS = build_property_registries(self.config.paths)
1301
+
1302
  self._openai_client = None
1303
  self._openai_unavailable_reason = None
1304
  self._data_extractor = None
 
1315
 
1316
  self.system_prompt = self._build_system_prompt()
1317
 
1318
+ # -------------------------------------------------------------------------
1319
+ # OpenAI client
1320
+ # -------------------------------------------------------------------------
1321
  @property
1322
  def openai_client(self):
1323
  if self._openai_client is None:
 
1339
  return (
1340
  "You are the tool-planning module for **PolyAgent**, a polymer-science agent.\n"
1341
  "Your job is to inspect the user's questions and decide which tools\n"
1342
+ "to run in which order.\n\n"
1343
  "Critical tool dependencies:\n"
1344
  "- property_prediction should run AFTER cl_encoding when possible and should reuse cl_encoding.embedding.\n"
1345
  "- polymer_generation is inverse-design and REQUIRES target_value (property -> PSMILES).\n\n"
 
1348
  )
1349
 
1350
  # =============================================================================
1351
+ # Planner: LLM tool-calling
1352
  # =============================================================================
1353
  def analyze_query(self, user_query: str) -> Dict[str, Any]:
1354
  schema_keys = ["analysis", "tools_required", "execution_plan"]
 
1397
  }
1398
  }
1399
 
 
1400
  try:
1401
  response = self.openai_client.chat.completions.create(
1402
  model=self.config.model,
 
1422
 
1423
  raise RuntimeError("Tool-calling plan not returned; falling back to JSON mode.")
1424
  except Exception:
 
1425
  try:
1426
  response = self.openai_client.chat.completions.create(
1427
  model=self.config.model,
 
1467
  output = self._run_polymer_generation(step, intermediate_data)
1468
  elif tool_name == "web_search":
1469
  output = self._run_web_search(step, intermediate_data)
1470
+ elif tool_name == "report_generation":
1471
  output = self._run_report_generation(step, intermediate_data)
1472
  elif tool_name == "mol_render":
1473
  output = self._run_mol_render(step, intermediate_data)
 
1581
  "year": meta.get("year", ""),
1582
  "source": meta.get("source", meta.get("source_path", "")),
1583
  "venue": meta.get("venue", meta.get("journal", "")),
 
1584
  "url": meta.get("url") or meta.get("link") or meta.get("href") or "",
1585
  "doi": meta.get("doi") or "",
1586
  })
 
1705
  "attention_mask": torch.ones(1, 2048, dtype=torch.bool, device=self.config.device),
1706
  }
1707
 
1708
+ # psmiles encoder input
1709
  if self._psmiles_tokenizer is None:
1710
  try:
1711
+ from PolyFusion.DeBERTav2 import build_psmiles_tokenizer
1712
  self._psmiles_tokenizer = build_psmiles_tokenizer()
1713
  except Exception:
1714
  self._psmiles_tokenizer = None
 
1735
  with torch.no_grad():
1736
  embeddings_dict = self._cl_encoder.encode(batch_mods)
1737
 
 
1738
  required_modalities = ("gine", "schnet", "fp", "psmiles")
1739
  missing = [m for m in required_modalities if m not in embeddings_dict]
1740
  if missing:
 
1757
  import torch.nn as nn
1758
 
1759
  property_name = canonical_property_name(property_name)
1760
+ prop_ckpt = self.PROPERTY_HEAD_PATHS.get(property_name)
1761
+ prop_meta = self.PROPERTY_HEAD_META.get(property_name)
1762
 
1763
  if prop_ckpt is None:
1764
  raise ValueError(f"No property head registered for: {property_name}")
 
1778
 
1779
  ckpt = torch.load(prop_ckpt, map_location=self.config.device, weights_only=False)
1780
 
 
1781
  state_dict = None
1782
  for k in ("state_dict", "model_state_dict", "model_state", "head_state_dict", "regressor_state_dict"):
1783
  if isinstance(ckpt, dict) and k in ckpt and isinstance(ckpt[k], dict):
 
1803
 
1804
  head = RegressionHeadOnly(hidden_dim=600, dropout=float(meta.get("dropout", 0.1))).to(self.config.device)
1805
 
 
1806
  normalized = {}
1807
  for k, v in state_dict.items():
1808
  nk = k
 
1823
  head.load_state_dict(normalized, strict=False)
1824
  head.eval()
1825
 
 
1826
  y_scaler = None
1827
  if isinstance(ckpt, dict):
1828
  for sk in ("y_scaler", "scaler_y", "target_scaler", "y_normalizer"):
 
1849
  return {"error": "Specify property name"}
1850
 
1851
  property_name = canonical_property_name(property_name)
1852
+ if property_name not in self.PROPERTY_HEAD_PATHS:
1853
  return {"error": f"Unsupported property: {property_name}"}
1854
 
 
1855
  emb_from_cl = None
1856
  cl = data.get("cl_encoding", None)
1857
  if isinstance(cl, dict) and isinstance(cl.get("embedding"), list) and len(cl["embedding"]) == 600:
1858
  emb_from_cl = torch.tensor([cl["embedding"]], dtype=torch.float32, device=self.config.device)
1859
 
 
1860
  multimodal = data.get("data_extraction", None)
1861
  psmiles = data.get("psmiles", data.get("smiles", None))
1862
  if emb_from_cl is None:
 
1874
  with torch.no_grad():
1875
  embs = self._cl_encoder.encode(batch_mods)
1876
 
 
1877
  required_modalities = ("gine", "schnet", "fp", "psmiles")
1878
  missing = [m for m in required_modalities if m not in embs]
1879
  if missing:
1880
  return {"error": f"CL encoder did not return embeddings for modalities: {', '.join(missing)}"}
1881
 
1882
  all_embs = [embs[k] for k in required_modalities]
1883
+ emb_from_cl = torch.stack(all_embs, dim=0).mean(dim=0)
1884
  except Exception as e:
1885
  return {"error": f"Failed to compute CL embedding: {e}"}
1886
 
 
1887
  try:
1888
  head, y_scaler, meta, ckpt_path = self._load_property_head(property_name)
1889
  with torch.no_grad():
 
1891
 
1892
  pred_value = float(pred_norm)
1893
 
 
1894
  if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
1895
  try:
1896
  inv = y_scaler.inverse_transform(np.array([[pred_norm]], dtype=float))
1897
  pred_value = float(inv[0][0])
1898
  except Exception:
1899
  pred_value = float(pred_norm)
 
 
1900
  else:
1901
  mean = (meta or {}).get("scaler_mean", None)
1902
  scale = (meta or {}).get("scaler_scale", None)
 
 
1903
  try:
1904
  if isinstance(mean, list) and isinstance(scale, list) and len(mean) == 1 and len(scale) == 1:
1905
  pred_value = float(pred_norm) * float(scale[0]) + float(mean[0])
1906
  except Exception:
1907
  pred_value = float(pred_norm)
1908
 
 
1909
  out_psmiles = None
1910
  if isinstance(multimodal, dict):
1911
  out_psmiles = multimodal.get("canonical_psmiles")
 
1920
  "predictions": {property_name: pred_value},
1921
  "prediction_normalized": float(pred_norm),
1922
  "head_checkpoint_path": ckpt_path,
1923
+ "metadata_path": self.PROPERTY_HEAD_META.get(property_name, ""),
1924
  "normalization_applied": bool(
1925
  (y_scaler is not None and hasattr(y_scaler, "inverse_transform")) or
1926
  ((meta or {}).get("scaler_mean") is not None and (meta or {}).get("scaler_scale") is not None)
 
1930
  except Exception as e:
1931
  return {"error": f"Property prediction failed: {e}"}
1932
 
1933
+ # ----------------- Inverse design generator (CL + SELFIES-TED) ----------------- #
1934
  def _get_selfies_ted_backend(self, model_name: str) -> Tuple[Any, Any]:
 
 
 
1935
  if not model_name:
1936
  model_name = SELFIES_TED_MODEL_NAME
1937
  if model_name in self._selfies_ted_cache:
 
1942
  return tok, model
1943
 
1944
  def _load_property_generator(self, property_name: str):
 
 
 
 
 
 
 
1945
  property_name = canonical_property_name(property_name)
1946
  if property_name in self._property_generators:
1947
  return self._property_generators[property_name]
1948
 
1949
+ base_dir = self.GENERATOR_DIRS.get(property_name)
1950
  if base_dir is None:
1951
  raise ValueError(f"No generator registered for: {property_name}")
1952
  if not os.path.isdir(base_dir):
 
1994
  if not gpr_path or not os.path.exists(gpr_path):
1995
  raise FileNotFoundError(f"GPR *.joblib not found in {base_dir}")
1996
 
 
1997
  _install_unpickle_shims()
1998
+ scaler_y = _safe_joblib_load(scaler_path)
1999
+ latent_prop_model = _safe_joblib_load(gpr_path)
2000
 
 
2001
  selfies_ted_name = meta.get("selfies_ted_model", SELFIES_TED_MODEL_NAME)
2002
  tok, selfies_backbone = self._get_selfies_ted_backend(selfies_ted_name)
2003
 
 
2012
  ).to(self.config.device)
2013
 
2014
  ckpt = torch.load(decoder_path, map_location=self.config.device, weights_only=False)
 
2015
  state_dict = None
2016
  if isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
2017
  state_dict = ckpt
 
2051
  latent_noise_std: float = LATENT_NOISE_STD_GEN,
2052
  extra_factor: int = 8,
2053
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
 
 
 
 
 
 
2054
  def _l2_normalize_np(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
2055
  n = np.linalg.norm(x, axis=-1, keepdims=True)
2056
  return x / np.clip(n, eps, None)
2057
 
 
2058
  if y_scaler is not None and hasattr(y_scaler, "transform"):
2059
  target_s = float(y_scaler.transform(np.array([[target_value]], dtype=np.float32))[0, 0])
2060
  else:
 
2094
 
2095
  @torch.no_grad()
2096
  def _run_polymer_generation(self, step: Dict, data: Dict) -> Dict:
 
 
 
 
 
 
 
 
2097
  property_name = data.get("property", data.get("property_name", None))
2098
  if property_name is None:
2099
  return {"error": "Specify property name for generation"}
2100
 
2101
  property_name = canonical_property_name(property_name)
2102
+ if property_name not in self.GENERATOR_DIRS:
2103
  return {"error": f"Unsupported property: {property_name}"}
2104
 
 
2105
  if data.get("target_value", None) is not None:
2106
  target_value = data["target_value"]
2107
  elif data.get("target", None) is not None:
 
2135
 
2136
  latent_dim = int(getattr(decoder_model, "cl_emb_dim", 600))
2137
 
 
2138
  y_scaler = getattr(latent_prop_model, "y_scaler", None)
2139
  if y_scaler is None:
2140
  y_scaler = scaler_y if scaler_y is not None else None
2141
 
2142
  tol_scaled = float(tol_scaled_override) if tol_scaled_override is not None else float(meta.get("tol_scaled", 0.5))
2143
 
 
2144
  seed_latents: List[np.ndarray] = []
2145
  cl_enc = data.get("cl_encoding", None)
2146
  if isinstance(cl_enc, dict) and isinstance(cl_enc.get("embedding"), list):
 
2148
  if emb.shape[0] == latent_dim:
2149
  seed_latents.append(emb)
2150
 
 
2151
  seeds_str: List[str] = []
2152
  if isinstance(data.get("seed_psmiles_list"), list):
2153
  seeds_str.extend([str(x) for x in data["seed_psmiles_list"] if isinstance(x, str)])
 
2155
  seeds_str.append(str(data["seed_psmiles"]))
2156
  if data.get("psmiles") and not seeds_str:
2157
  seeds_str.append(str(data["psmiles"]))
 
2158
  seeds_str = list(dict.fromkeys(seeds_str))
2159
 
 
2160
  if seeds_str and not seed_latents:
2161
  self._ensure_cl_encoder()
2162
  for s in seeds_str:
 
2169
  if z.shape[0] == latent_dim:
2170
  seed_latents.append(z)
2171
 
 
2172
  try:
2173
  Z_keep, y_s_keep, y_u_keep, target_s = self._sample_latents_for_target(
2174
  latent_prop_model=latent_prop_model,
 
2184
  except Exception as e:
2185
  return {"error": f"Failed to sample latents conditioned on property: {e}", "paths": paths}
2186
 
 
2187
  at_bracket_re = re.compile(r"\[(at)\]", flags=re.IGNORECASE)
2188
 
2189
  def _at_to_star_bracket(s: str) -> str:
 
2192
  return at_bracket_re.sub("[*]", s)
2193
 
2194
  def _is_rdkit_valid(psmiles: str) -> bool:
 
2195
  if Chem is None:
2196
  return True
2197
  try:
 
2201
  except Exception:
2202
  return False
2203
 
 
 
2204
  requested_k = int(num_samples)
 
 
 
2205
  candidates: List[Tuple[int, float, str, str, float, float]] = []
2206
 
 
2207
  candidates_per_latent = max(1, int(extra_factor))
2208
+ max_gen_rounds = 4
2209
 
2210
  Z_round, y_s_round, y_u_round = Z_keep, y_s_keep, y_u_keep
2211
  for _round in range(max_gen_rounds):
 
2223
  for selfies_str in (outs or []):
2224
  psm_raw = pselfies_to_psmiles(selfies_str)
2225
 
 
2226
  if _is_rdkit_valid(psm_raw):
 
2227
  psm_out = _at_to_star_bracket(psm_raw)
2228
  candidates.append(
2229
  (
 
2238
  except Exception:
2239
  continue
2240
 
 
2241
  if len(candidates) >= requested_k:
2242
  break
2243
 
 
2244
  try:
2245
  Z_round, y_s_round, y_u_round, target_s = self._sample_latents_for_target(
2246
  latent_prop_model=latent_prop_model,
 
2256
  except Exception:
2257
  break
2258
 
 
2259
  candidates.sort(key=lambda t: (t[0], t[1]))
2260
  selected = candidates[:requested_k]
2261
 
 
2262
  if selected and len(selected) < requested_k:
2263
  while len(selected) < requested_k:
2264
  selected.append(selected[0])
 
2272
  "property": property_name,
2273
  "target_value": float(target_value),
2274
  "num_samples": int(len(generated_psmiles)),
2275
+ "generated_psmiles": generated_psmiles,
2276
+ "generated_selfies": selfies_raw,
2277
  "latent_property_predictions": {
2278
  "scaled": decoded_scaled,
2279
  "unscaled": decoded_unscaled,
 
2324
  doi = normalize_doi(it.get("DOI", "")) or ""
2325
 
2326
  publisher = (it.get("publisher") or "").lower()
 
2327
  if doi and doi.startswith("10.1163/"):
2328
  continue
2329
  if "brill" in publisher:
2330
  continue
2331
+
2332
  pub_year = None
2333
  if it.get("published-print") and isinstance(it["published-print"].get("date-parts"), list):
2334
  pub_year = it["published-print"]["date-parts"][0][0]
 
2340
  doi = ""
2341
  doi_url = ""
2342
 
 
2343
  landing = (it.get("URL") or "") if isinstance(it.get("URL"), str) else ""
2344
  out.append({
2345
  "title": title,
 
2370
  continue
2371
 
2372
  doi = normalize_doi(it.get("doi", "")) or ""
 
2373
  if doi and doi.startswith("10.1163/"):
2374
  continue
2375
 
 
2386
 
2387
  out.append({
2388
  "title": it.get("title", ""),
2389
+ "doi": doi,
2390
+ "url": landing or "",
2391
  "year": it.get("publication_year") or (it.get("publication_date", "")[:4]),
2392
  "venue": (it.get("host_venue") or {}).get("display_name", ""),
2393
+ "type": oa_type,
2394
+ "source": "OpenAlex",
2395
  })
2396
  return out
2397
  except Exception as e:
 
2585
  return {"error": f"Unsupported web_search source: {src}"}
2586
 
2587
  # =============================================================================
2588
+ # REPORT GENERATION
2589
  # =============================================================================
2590
  def generate_report(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
 
 
 
 
 
 
2591
  payload = dict(data or {})
2592
  summary: Dict[str, Any] = {}
2593
 
 
2594
  prop = payload.get("property") or payload.get("property_name")
2595
  if prop:
2596
  payload["property"] = prop
2597
 
 
2598
  if not payload.get("property"):
2599
  qtxt = payload.get("questions") or payload.get("question") or ""
2600
  inferred_prop = infer_property_from_text(qtxt)
 
2605
  if psmiles:
2606
  payload["psmiles"] = psmiles
2607
 
 
2608
  if payload.get("target_value", None) is None:
2609
  qtxt = payload.get("questions") or payload.get("question") or ""
2610
  inferred_tgt = infer_target_value_from_text(qtxt, payload.get("property"))
2611
  if inferred_tgt is not None:
2612
  payload["target_value"] = float(inferred_tgt)
2613
 
 
2614
  if psmiles and "data_extraction" not in payload:
2615
  ex = self._run_data_extraction({"step": -1}, payload)
2616
  payload["data_extraction"] = ex
2617
  summary["data_extraction"] = ex
2618
 
 
2619
  if "data_extraction" in payload and "cl_encoding" not in payload:
2620
  cl = self._run_cl_encoding({"step": -1}, payload)
2621
  payload["cl_encoding"] = cl
2622
  summary["cl_encoding"] = cl
2623
 
 
2624
  if payload.get("property") and "property_prediction" not in payload:
2625
  pp = self._run_property_prediction({"step": -1}, payload)
2626
  payload["property_prediction"] = pp
2627
  summary["property_prediction"] = pp
2628
 
 
2629
  do_gen = bool(payload.get("generate", False)) or (payload.get("target_value", None) is not None)
2630
  if do_gen and payload.get("property") and payload.get("target_value", None) is not None:
2631
  gen = self._run_polymer_generation({"step": -1}, payload)
2632
  payload["polymer_generation"] = gen
2633
  summary["generation"] = gen
2634
 
 
2635
  q = payload.get("query") or payload.get("literature_query")
2636
  src = payload.get("source") or "all"
2637
  if q:
 
2652
  "questions": payload.get("questions") or payload.get("question") or "",
2653
  }
2654
 
 
2655
  report = _attach_source_domains(report)
2656
  report = _index_citable_sources(report)
2657
  report = _assign_tool_tags_to_report(report)
 
2661
  def _run_report_generation(self, step: Dict, data: Dict) -> Dict[str, Any]:
2662
  return self.generate_report(data)
2663
 
2664
+ # =============================================================================
2665
+ # COMPOSER
2666
+ # =============================================================================
2667
  def compose_gpt_style_answer(
2668
  self,
2669
  report: Dict[str, Any],
2670
  case_brief: str = "",
2671
  questions: str = "",
2672
  ) -> Tuple[str, List[str]]:
 
 
 
 
 
 
 
 
 
 
2673
  imgs: List[str] = []
2674
 
 
2675
  if isinstance(report, dict):
2676
  report = _attach_source_domains(report)
2677
  report = _index_citable_sources(report)
2678
  report = _assign_tool_tags_to_report(report)
2679
 
2680
  if self.openai_client is None:
 
2681
  md_lines = []
2682
  if case_brief:
2683
  md_lines.append(case_brief.strip())
 
2692
  md_lines.append(str(report))
2693
  md_lines.append("```")
2694
 
 
2695
  verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
2696
  if verb:
2697
  md_lines.append("\n---\n\n## Tool outputs (verbatim)\n")
 
2699
 
2700
  return "\n".join(md_lines), imgs
2701
 
 
2702
  try:
2703
  prompt = (
2704
  "You are PolyAgent - consider yourself as an expert in polymer science. Answer the user's questions using ONLY the provided report.\n"
 
2714
  "- NON-DUPLICATES: Do not repeat the same paper link. Each DOI/URL may appear at most once in the entire answer.\n"
2715
  "- Each major section should include at least 1 inline literature citation when relevant.\n"
2716
  "- Do NOT invent DOIs, URLs, titles, or sources.\n\n"
 
2717
  "OUTPUT RULES (STRICT):\n"
2718
  "- If a numeric value is not present in the report, write 'not available'.\n"
2719
  "- Preserve polymer endpoint tokens exactly as '[*]' in any pSMILES/SMILES shown.\n"
2720
  "- To prevent markdown mangling, put any pSMILES/SMILES inside code formatting.\n"
2721
+ "- Do not rewrite or tweak any tool outputs; if you refer to them, reference them by tag (e.g., [T]).\n\n"
2722
  f"CASE BRIEF:\n{case_brief}\n\n"
2723
  f"QUESTIONS:\n{questions}\n\n"
2724
  f"REPORT (JSON):\n{json.dumps(report, ensure_ascii=False)}\n"
 
2734
  )
2735
  txt = resp.choices[0].message.content or ""
2736
 
 
 
2737
  try:
2738
  min_cites = _infer_required_citation_count(questions or "", default_n=10)
2739
  txt = _ensure_distributed_inline_citations(txt, report, min_needed=min_cites)
2740
  except Exception:
2741
  pass
2742
 
 
 
2743
  try:
2744
  txt = _normalize_and_dedupe_literature_links(txt, report)
2745
  except Exception:
 
2750
  except Exception:
2751
  pass
2752
 
 
2753
  verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
2754
  if verb:
2755
  txt = txt.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
2756
 
2757
  return txt, imgs
2758
  except Exception as e:
 
2759
  md = f"OpenAI compose failed: {e}\n\n```json\n{json.dumps(report, indent=2, ensure_ascii=False)}\n```"
 
2760
  verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
2761
  if verb:
2762
  md = md.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
2763
  return md, imgs
2764
 
2765
  # =============================================================================
2766
+ # VISUAL TOOLS
2767
  # =============================================================================
2768
  def _run_mol_render(self, step: Dict, data: Dict) -> Dict[str, Any]:
2769
  out_dir = Path("viz")
 
2817
  return {"png_path": png, "n": len(mols)}
2818
 
2819
  def _run_prop_attribution(self, step: Dict, data: Dict) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
2820
  out_dir = Path("viz")
2821
  out_dir.mkdir(parents=True, exist_ok=True)
2822
 
 
2827
  prop = canonical_property_name(data.get("property") or data.get("property_name") or "glass transition")
2828
  top_k = int(data.get("top_k_atoms", data.get("top_k", 12)))
2829
 
 
2830
  min_rel_importance = float(data.get("min_rel_importance", 0.25))
2831
  min_abs_importance = float(data.get("min_abs_importance", 0.0))
2832
 
2833
+ if prop not in self.PROPERTY_HEAD_PATHS:
2834
  return {"error": f"Unsupported property for attribution: {prop}"}
2835
  if not p:
2836
  return {"error": "no psmiles"}
 
2851
  if not isinstance(baseline, (float, int)):
2852
  return {"error": "Baseline prediction not numeric"}
2853
 
 
2854
  scores: Dict[int, float] = {}
2855
  for idx in range(num_atoms):
2856
  try:
 
2858
  tmp.GetAtomWithIdx(idx).SetAtomicNum(0) # wildcard
2859
  mutated = tmp.GetMol()
2860
  mut_smiles = Chem.MolToSmiles(mutated)
2861
+ mut_psmiles = normalize_generated_psmiles_out(mut_smiles)
2862
  except Exception:
2863
  scores[idx] = 0.0
2864
  continue
 
2870
  else:
2871
  scores[idx] = float(baseline) - float(mut_val)
2872
 
 
2873
  max_abs = max((abs(v) for v in scores.values()), default=0.0)
2874
  rel_thresh = (min_rel_importance * max_abs) if max_abs > 0 else 0.0
2875
  thresh = max(float(min_abs_importance), float(rel_thresh))
 
2880
  selected = [i for i, v in ranked if abs(v) >= thresh]
2881
  selected = selected[:k_cap]
2882
 
 
2883
  if not selected and ranked:
2884
  selected = [ranked[0][0]]
2885
 
 
2886
  atom_colors: Dict[int, tuple] = {}
2887
  sel_scores = np.array([scores[i] for i in selected], dtype=float)
2888
  if cm is not None and sel_scores.size > 0:
 
2930
  except Exception as e:
2931
  return {"error": f"prop_attribution rendering failed: {e}"}
2932
 
 
2933
  def process_query(self, user_query: str, user_inputs: Dict[str, Any] = None) -> Dict[str, Any]:
2934
  plan = self.analyze_query(user_query)
2935
  results = self.execute_plan(plan, user_inputs)
 
2937
 
2938
 
2939
  if __name__ == "__main__":
2940
+ cfg = OrchestratorConfig(paths=PathsConfig())
2941
  orch = PolymerOrchestrator(cfg)
2942
+ print("PolymerOrchestrator ready (5M heads + 5M inverse-design + LLM planner + occlusion explainability).")