Spaces:
Running
Running
manpreet88 commited on
Commit ·
30fd755
1
Parent(s): 8b37da0
Update orchestrator.py
Browse files- PolyAgent/orchestrator.py +174 -288
PolyAgent/orchestrator.py
CHANGED
|
@@ -1,3 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import json
|
|
@@ -6,7 +20,7 @@ import sys
|
|
| 6 |
from pathlib import Path
|
| 7 |
from typing import Dict, Any, List, Optional, Tuple
|
| 8 |
from urllib.parse import urlparse
|
| 9 |
-
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
|
@@ -47,7 +61,7 @@ try:
|
|
| 47 |
except Exception:
|
| 48 |
spm = None
|
| 49 |
|
| 50 |
-
# Optional: selfies (for SELFIES→SMILES/PSMILES conversion
|
| 51 |
try:
|
| 52 |
import selfies as sf
|
| 53 |
except Exception:
|
|
@@ -57,11 +71,36 @@ RDKit_AVAILABLE = Chem is not None
|
|
| 57 |
SELFIES_AVAILABLE = sf is not None
|
| 58 |
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
# =============================================================================
|
| 61 |
# DOI NORMALIZATION / RESOLUTION HELPERS
|
| 62 |
# =============================================================================
|
| 63 |
_DOI_RE = re.compile(r"^10\.\d{4,9}/\S+$", re.IGNORECASE)
|
| 64 |
|
|
|
|
| 65 |
def normalize_doi(raw: str) -> Optional[str]:
|
| 66 |
if not isinstance(raw, str):
|
| 67 |
return None
|
|
@@ -75,10 +114,12 @@ def normalize_doi(raw: str) -> Optional[str]:
|
|
| 75 |
s = s.rstrip(").,;]}")
|
| 76 |
return s if _DOI_RE.match(s) else None
|
| 77 |
|
|
|
|
| 78 |
def doi_to_url(doi: str) -> str:
|
| 79 |
# doi is assumed normalized
|
| 80 |
return f"https://doi.org/{doi}"
|
| 81 |
|
|
|
|
| 82 |
def doi_resolves(doi_url: str, timeout: float = 6.0) -> bool:
|
| 83 |
"""
|
| 84 |
Best-effort resolver check. Keeps pipeline robust against dead/unregistered DOIs.
|
|
@@ -95,8 +136,9 @@ def doi_resolves(doi_url: str, timeout: float = 6.0) -> bool:
|
|
| 95 |
except Exception:
|
| 96 |
return False
|
| 97 |
|
|
|
|
| 98 |
# =============================================================================
|
| 99 |
-
# CITATION / DOMAIN TAGGING HELPERS
|
| 100 |
# =============================================================================
|
| 101 |
def _url_to_domain(url: str) -> Optional[str]:
|
| 102 |
if not isinstance(url, str) or not url.strip():
|
|
@@ -133,10 +175,12 @@ def _url_to_domain(url: str) -> Optional[str]:
|
|
| 133 |
except Exception:
|
| 134 |
return None
|
| 135 |
|
|
|
|
| 136 |
def _attach_source_domains(obj: Any) -> Any:
|
| 137 |
"""
|
| 138 |
Recursively add a short source_domain field where URLs are present.
|
| 139 |
-
This enables domain-style citations like "(nature.com)"
|
|
|
|
| 140 |
"""
|
| 141 |
if isinstance(obj, list):
|
| 142 |
return [_attach_source_domains(x) for x in obj]
|
|
@@ -160,6 +204,7 @@ def _attach_source_domains(obj: Any) -> Any:
|
|
| 160 |
def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
|
| 161 |
"""
|
| 162 |
Add 'cite_tag' fields for citable web/RAG items using DOI-first URL tags.
|
|
|
|
| 163 |
Requirement:
|
| 164 |
- Paper citations must use the COMPLETE DOI URL (https://doi.org/...) as the bracket text.
|
| 165 |
- If DOI is not available, fall back to the best http(s) URL.
|
|
@@ -181,7 +226,7 @@ def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 181 |
return False
|
| 182 |
|
| 183 |
def get_best_url(d: Dict[str, Any]) -> Optional[str]:
|
| 184 |
-
# DOI-first
|
| 185 |
doi = normalize_doi(d.get("doi", ""))
|
| 186 |
if doi:
|
| 187 |
return doi_to_url(doi)
|
|
@@ -199,23 +244,16 @@ def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 199 |
out = {k: walk_and_tag(v) for k, v in node.items()}
|
| 200 |
|
| 201 |
if is_citable_item(out):
|
| 202 |
-
# Citation tag MUST be DOI URL (preferred) or best URL (fallback).
|
| 203 |
url = get_best_url(out)
|
| 204 |
if isinstance(url, str) and url.startswith(("http://", "https://")):
|
| 205 |
-
# If an existing cite_tag is non-URL (e.g., a domain tag), replace it.
|
| 206 |
cur = out.get("cite_tag")
|
| 207 |
if not (isinstance(cur, str) and cur.strip().startswith(("http://", "https://"))):
|
| 208 |
out["cite_tag"] = url.strip()
|
| 209 |
-
else:
|
| 210 |
-
# If we cannot form a URL, leave as-is (should be rare due to is_citable_item).
|
| 211 |
-
pass
|
| 212 |
|
| 213 |
-
# Maintain a compact index (optional, harmless for UIs)
|
| 214 |
url = get_best_url(out)
|
| 215 |
dom = out.get("source_domain") or (_url_to_domain(url) if url else None) or "source"
|
| 216 |
citation_index["sources"].append(
|
| 217 |
{
|
| 218 |
-
# tag is the bracket text requirement (DOI URL or URL)
|
| 219 |
"tag": out.get("cite_tag") if isinstance(out.get("cite_tag"), str) else url,
|
| 220 |
"domain": dom,
|
| 221 |
"title": out.get("title") or out.get("name") or "Untitled",
|
|
@@ -236,10 +274,9 @@ def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 236 |
return report
|
| 237 |
|
| 238 |
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
# =============================================================================
|
| 243 |
_CITE_COUNT_PATTERNS = [
|
| 244 |
r"(?:at\s+least\s+)?(\d{1,3})\s*(?:citations|citation|papers|paper|sources|source|references|reference)\b",
|
| 245 |
r"\bcite\s+(\d{1,3})\s*(?:papers|paper|sources|source|references|reference|citations|citation)\b",
|
|
@@ -262,8 +299,8 @@ def _infer_required_citation_count(text: str, default_n: int = 10) -> int:
|
|
| 262 |
|
| 263 |
def _collect_citation_links_from_report(report: Dict[str, Any]) -> List[Tuple[str, str]]:
|
| 264 |
"""
|
| 265 |
-
Return unique (
|
| 266 |
-
|
| 267 |
"""
|
| 268 |
out: List[Tuple[str, str]] = []
|
| 269 |
seen: set = set()
|
|
@@ -280,7 +317,6 @@ def _collect_citation_links_from_report(report: Dict[str, Any]) -> List[Tuple[st
|
|
| 280 |
url = s.get("url")
|
| 281 |
if not isinstance(url, str) or not url.startswith(("http://", "https://")):
|
| 282 |
continue
|
| 283 |
-
# cite_text is DOI URL (tag) if present; else fall back to the URL itself.
|
| 284 |
cite_text = s.get("tag") if isinstance(s.get("tag"), str) and s.get("tag").strip() else url
|
| 285 |
if not isinstance(cite_text, str) or not cite_text.strip():
|
| 286 |
cite_text = url
|
|
@@ -310,18 +346,15 @@ def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_ne
|
|
| 310 |
if not citations:
|
| 311 |
return md
|
| 312 |
|
| 313 |
-
# Count existing literature links by URL (any markdown link).
|
| 314 |
existing_urls = set(re.findall(r"\[[^\]]+\]\((https?://[^)]+)\)", md))
|
| 315 |
need = max(0, int(min_needed) - len(existing_urls))
|
| 316 |
if need <= 0:
|
| 317 |
return md
|
| 318 |
|
| 319 |
-
# Only use citations not already present.
|
| 320 |
remaining: List[Tuple[str, str]] = [(d, u) for (d, u) in citations if u not in existing_urls]
|
| 321 |
if not remaining:
|
| 322 |
return md
|
| 323 |
|
| 324 |
-
# Split by fenced code blocks; do not inject inside them.
|
| 325 |
parts = re.split(r"(```[\s\S]*?```)", md)
|
| 326 |
rem_i = 0
|
| 327 |
|
|
@@ -331,7 +364,6 @@ def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_ne
|
|
| 331 |
if part.startswith("```") and part.endswith("```"):
|
| 332 |
continue
|
| 333 |
|
| 334 |
-
# Split into paragraph blocks (preserve blank-line separators).
|
| 335 |
segs = re.split(r"(\n\s*\n)", part)
|
| 336 |
for si in range(0, len(segs), 2):
|
| 337 |
if rem_i >= len(remaining) or need <= 0:
|
|
@@ -339,27 +371,24 @@ def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_ne
|
|
| 339 |
para = segs[si]
|
| 340 |
if not isinstance(para, str) or not para.strip():
|
| 341 |
continue
|
| 342 |
-
# Skip headings.
|
| 343 |
if para.lstrip().startswith("#"):
|
| 344 |
continue
|
| 345 |
-
# Skip paragraphs that already contain at least one markdown link.
|
| 346 |
if re.search(r"\[[^\]]+\]\((https?://[^)]+)\)", para):
|
| 347 |
continue
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
|
|
|
| 352 |
continue
|
| 353 |
|
| 354 |
cite_text, url = remaining[rem_i]
|
| 355 |
-
# Requirement: bracket text is the COMPLETE DOI URL (or URL fallback).
|
| 356 |
segs[si] = para.rstrip() + f" [{cite_text}]({url})"
|
| 357 |
rem_i += 1
|
| 358 |
need -= 1
|
| 359 |
|
| 360 |
parts[pi] = "".join(segs)
|
| 361 |
|
| 362 |
-
# Second pass (if still need citations): allow any non-heading paragraph without links.
|
| 363 |
if need > 0 and rem_i < len(remaining):
|
| 364 |
md2 = "".join(parts)
|
| 365 |
parts2 = re.split(r"(```[\s\S]*?```)", md2)
|
|
@@ -391,9 +420,9 @@ def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_ne
|
|
| 391 |
|
| 392 |
def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> str:
|
| 393 |
"""
|
| 394 |
-
Enforce
|
| 395 |
-
- Link text must be
|
| 396 |
-
- Each DOI/URL
|
| 397 |
Only operates outside fenced code blocks.
|
| 398 |
"""
|
| 399 |
if not isinstance(md, str) or not md.strip():
|
|
@@ -401,7 +430,6 @@ def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> s
|
|
| 401 |
if not isinstance(report, dict):
|
| 402 |
return md
|
| 403 |
|
| 404 |
-
# Build url -> preferred_text mapping (DOI URL / URL)
|
| 405 |
url_to_text: Dict[str, str] = {}
|
| 406 |
ci = report.get("citation_index", {})
|
| 407 |
sources = ci.get("sources") if isinstance(ci, dict) else None
|
|
@@ -421,22 +449,19 @@ def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> s
|
|
| 421 |
|
| 422 |
def _rewrite_and_dedupe(text: str) -> str:
|
| 423 |
def repl(m: re.Match) -> str:
|
| 424 |
-
txt = m.group(1)
|
| 425 |
url = m.group(2).strip()
|
| 426 |
if url in seen_urls:
|
| 427 |
-
# remove duplicate citation entirely (and any leading space before it if present)
|
| 428 |
return ""
|
| 429 |
seen_urls.add(url)
|
| 430 |
pref = url_to_text.get(url, url)
|
| 431 |
return f"[{pref}]({url})"
|
| 432 |
-
|
| 433 |
return re.sub(r"\[([^\]]+)\]\((https?://[^)]+)\)", repl, text)
|
| 434 |
|
| 435 |
for i, part in enumerate(parts):
|
| 436 |
if part.startswith("```") and part.endswith("```"):
|
| 437 |
continue
|
| 438 |
parts[i] = _rewrite_and_dedupe(part)
|
| 439 |
-
# Cleanup: collapse double spaces created by removals
|
| 440 |
parts[i] = re.sub(r"[ \t]{2,}", " ", parts[i])
|
| 441 |
parts[i] = re.sub(r"\n{3,}", "\n\n", parts[i])
|
| 442 |
|
|
@@ -446,7 +471,6 @@ def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> s
|
|
| 446 |
def autolink_doi_urls(md: str) -> str:
|
| 447 |
"""
|
| 448 |
Wrap bare DOI URLs in Markdown links outside code blocks.
|
| 449 |
-
Prevents plain DOI URLs from rendering as non-clickable text.
|
| 450 |
"""
|
| 451 |
if not md:
|
| 452 |
return md
|
|
@@ -457,15 +481,18 @@ def autolink_doi_urls(md: str) -> str:
|
|
| 457 |
parts[i] = re.sub(
|
| 458 |
r"(?<!\]\()(?P<u>https?://doi\.org/10\.\d{4,9}/[^\s\)\],;]+)",
|
| 459 |
lambda m: f"[{m.group('u')}]({m.group('u')})",
|
| 460 |
-
|
| 461 |
flags=re.IGNORECASE,
|
| 462 |
)
|
| 463 |
return "".join(parts)
|
| 464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
|
| 466 |
"""
|
| 467 |
-
Ensure each tool output has a [T
|
| 468 |
-
This does NOT modify tool outputs beyond adding a 'cite_tag' key when missing.
|
| 469 |
"""
|
| 470 |
if not isinstance(report, dict):
|
| 471 |
return report
|
|
@@ -474,8 +501,7 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 474 |
if not isinstance(tool_outputs, dict):
|
| 475 |
return report
|
| 476 |
|
| 477 |
-
|
| 478 |
-
ordered_tools = [
|
| 479 |
"data_extraction",
|
| 480 |
"cl_encoding",
|
| 481 |
"property_prediction",
|
|
@@ -485,12 +511,10 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 485 |
"report_generation",
|
| 486 |
]
|
| 487 |
|
| 488 |
-
# Tag assignment: keep existing cite_tags if present
|
| 489 |
tool_tag_map: Dict[str, str] = {}
|
| 490 |
tag = "[T]"
|
| 491 |
|
| 492 |
-
|
| 493 |
-
for tool in ordered_tools:
|
| 494 |
node = tool_outputs.get(tool)
|
| 495 |
if node is None:
|
| 496 |
continue
|
|
@@ -498,7 +522,6 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 498 |
if isinstance(node, dict) and not node.get("cite_tag"):
|
| 499 |
node["cite_tag"] = tag
|
| 500 |
|
| 501 |
-
# Second pass: any remaining tools in tool_outputs
|
| 502 |
for tool, node in tool_outputs.items():
|
| 503 |
if tool in tool_tag_map or node is None:
|
| 504 |
continue
|
|
@@ -506,11 +529,9 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 506 |
if isinstance(node, dict) and not node.get("cite_tag"):
|
| 507 |
node["cite_tag"] = tag
|
| 508 |
|
| 509 |
-
# Also tag summary nodes (best-effort, no structural assumptions)
|
| 510 |
try:
|
| 511 |
summary = report.get("summary", {}) or {}
|
| 512 |
if isinstance(summary, dict):
|
| 513 |
-
# common mapping
|
| 514 |
key_to_tool = {
|
| 515 |
"data_extraction": "data_extraction",
|
| 516 |
"cl_encoding": "cl_encoding",
|
|
@@ -534,7 +555,7 @@ def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
|
|
| 534 |
|
| 535 |
def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
|
| 536 |
"""
|
| 537 |
-
Render tool outputs as verbatim JSON blocks (no
|
| 538 |
"""
|
| 539 |
if not isinstance(report, dict):
|
| 540 |
return ""
|
|
@@ -543,7 +564,6 @@ def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
|
|
| 543 |
if not isinstance(tool_outputs, dict):
|
| 544 |
return ""
|
| 545 |
|
| 546 |
-
# Prefer a stable display order; include any extra keys afterward
|
| 547 |
preferred = [
|
| 548 |
"data_extraction",
|
| 549 |
"cl_encoding",
|
|
@@ -571,28 +591,19 @@ def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
|
|
| 571 |
|
| 572 |
|
| 573 |
# =============================================================================
|
| 574 |
-
# PICKLE / JOBLIB COMPATIBILITY SHIMS
|
| 575 |
# =============================================================================
|
| 576 |
class LatentPropertyModel:
|
| 577 |
"""
|
| 578 |
Compatibility shim for joblib/pickle artifacts saved with references like:
|
| 579 |
__main__.LatentPropertyModel
|
| 580 |
-
|
| 581 |
-
The original training code likely defined this in a script, so pickle recorded it under __main__.
|
| 582 |
-
When loading from Gradio, __main__ is different, so unpickling fails.
|
| 583 |
-
|
| 584 |
-
This shim is intentionally minimal:
|
| 585 |
-
- pickle will restore attributes into this object
|
| 586 |
-
- predict(...) attempts to delegate to a plausible underlying model attribute if present
|
| 587 |
"""
|
| 588 |
def predict(self, X):
|
| 589 |
-
# Common patterns: wrapper stores underlying estimator under one of these attributes.
|
| 590 |
for attr in ("model", "gpr", "gpr_model", "estimator", "predictor", "_model", "_gpr"):
|
| 591 |
if hasattr(self, attr):
|
| 592 |
obj = getattr(self, attr)
|
| 593 |
if hasattr(obj, "predict"):
|
| 594 |
return obj.predict(X)
|
| 595 |
-
# If the wrapper itself has been restored with a custom predict, this will never be hit.
|
| 596 |
raise AttributeError(
|
| 597 |
"LatentPropertyModel shim could not find an underlying predictor. "
|
| 598 |
"Artifact expects a wrapped model attribute with a .predict method."
|
|
@@ -602,7 +613,6 @@ class LatentPropertyModel:
|
|
| 602 |
def _install_unpickle_shims() -> None:
|
| 603 |
"""
|
| 604 |
Ensure that any classes pickled under __main__ are available at load time.
|
| 605 |
-
This is critical for joblib artifacts created from scripts (training/fit scripts).
|
| 606 |
"""
|
| 607 |
main_mod = sys.modules.get("__main__")
|
| 608 |
if main_mod is not None and not hasattr(main_mod, "LatentPropertyModel"):
|
|
@@ -620,7 +630,6 @@ def _safe_joblib_load(path: str):
|
|
| 620 |
return joblib.load(path)
|
| 621 |
except Exception as e:
|
| 622 |
msg = str(e)
|
| 623 |
-
# Targeted fix for your exact failure mode
|
| 624 |
if "Can't get attribute 'LatentPropertyModel' on <module '__main__'" in msg:
|
| 625 |
_install_unpickle_shims()
|
| 626 |
return joblib.load(path)
|
|
@@ -628,18 +637,50 @@ def _safe_joblib_load(path: str):
|
|
| 628 |
|
| 629 |
|
| 630 |
# =============================================================================
|
| 631 |
-
#
|
| 632 |
# =============================================================================
|
| 633 |
-
|
| 634 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
|
| 636 |
|
| 637 |
# =============================================================================
|
| 638 |
-
# Property name canonicalization
|
| 639 |
# =============================================================================
|
| 640 |
def canonical_property_name(name: str) -> str:
|
| 641 |
"""
|
| 642 |
-
Map user/tool inputs to the canonical keys used in
|
| 643 |
"""
|
| 644 |
if not isinstance(name, str):
|
| 645 |
return ""
|
|
@@ -663,15 +704,11 @@ def canonical_property_name(name: str) -> str:
|
|
| 663 |
return aliases.get(s, s)
|
| 664 |
|
| 665 |
|
| 666 |
-
# =============================================================================
|
| 667 |
-
# NEW: best-effort inference of property + target_value from questions text
|
| 668 |
-
# (used only when callers omit property/target_value but provide questions)
|
| 669 |
-
# =============================================================================
|
| 670 |
_NUM_RE = r"[-+]?\d+(?:\.\d+)?"
|
| 671 |
|
|
|
|
| 672 |
def infer_property_from_text(text: str) -> Optional[str]:
|
| 673 |
s = (text or "").lower()
|
| 674 |
-
# explicit "property: ..."
|
| 675 |
m = re.search(r"\bproperty\b\s*[:=]\s*([a-zA-Z _-]+)", s)
|
| 676 |
if m:
|
| 677 |
cand = m.group(1).strip().lower()
|
|
@@ -698,6 +735,7 @@ def infer_property_from_text(text: str) -> Optional[str]:
|
|
| 698 |
return "density"
|
| 699 |
return None
|
| 700 |
|
|
|
|
| 701 |
def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[float]:
|
| 702 |
sl = (text or "").lower()
|
| 703 |
|
|
@@ -729,7 +767,6 @@ def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[flo
|
|
| 729 |
except Exception:
|
| 730 |
pass
|
| 731 |
|
| 732 |
-
# token-near-number fallback (within 80 chars)
|
| 733 |
tokens = []
|
| 734 |
if prop == "glass transition":
|
| 735 |
tokens = ["tg", "glass transition"]
|
|
@@ -754,33 +791,9 @@ def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[flo
|
|
| 754 |
|
| 755 |
return None
|
| 756 |
|
| 757 |
-
PROPERTY_HEAD_PATHS = {
|
| 758 |
-
"density": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "density", "best_run_checkpoint.pt"),
|
| 759 |
-
"glass transition": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "glass_transition", "best_run_checkpoint.pt"),
|
| 760 |
-
"melting": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "melting", "best_run_checkpoint.pt"),
|
| 761 |
-
"specific volume": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "specific_volume", "best_run_checkpoint.pt"),
|
| 762 |
-
"thermal decomposition": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "thermal_decomposition", "best_run_checkpoint.pt"),
|
| 763 |
-
}
|
| 764 |
-
|
| 765 |
-
PROPERTY_HEAD_META = {
|
| 766 |
-
"density": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "density", "best_run_metadata.json"),
|
| 767 |
-
"glass transition": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "glass_transition", "best_run_metadata.json"),
|
| 768 |
-
"melting": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "melting", "best_run_metadata.json"),
|
| 769 |
-
"specific volume": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "specific_volume", "best_run_metadata.json"),
|
| 770 |
-
"thermal decomposition": os.path.join(DOWNSTREAM_BESTWEIGHTS_5M_DIR, "thermal_decomposition", "best_run_metadata.json"),
|
| 771 |
-
}
|
| 772 |
-
|
| 773 |
-
GENERATOR_DIRS = {
|
| 774 |
-
"density": os.path.join(INVERSE_DESIGN_5M_DIR, "density"),
|
| 775 |
-
"glass transition": os.path.join(INVERSE_DESIGN_5M_DIR, "glass_transition"),
|
| 776 |
-
"melting": os.path.join(INVERSE_DESIGN_5M_DIR, "melting"),
|
| 777 |
-
"specific volume": os.path.join(INVERSE_DESIGN_5M_DIR, "specific_volume"),
|
| 778 |
-
"thermal decomposition": os.path.join(INVERSE_DESIGN_5M_DIR, "thermal_decomposition"),
|
| 779 |
-
}
|
| 780 |
-
|
| 781 |
|
| 782 |
# =============================================================================
|
| 783 |
-
# Tokenizers
|
| 784 |
# =============================================================================
|
| 785 |
class SimpleCharTokenizer:
|
| 786 |
def __init__(self, vocab_chars: List[str], special_tokens=("<pad>", "<s>", "</s>", "<unk>")):
|
|
@@ -840,7 +853,6 @@ class SentencePieceTokenizerWrapper:
|
|
| 840 |
blocked.append(tid)
|
| 841 |
setattr(self, "_blocked_ids", blocked)
|
| 842 |
|
| 843 |
-
# Safety: require '*' token
|
| 844 |
if self.PieceToId("*") is None:
|
| 845 |
raise RuntimeError("SentencePiece tokenizer loaded but '*' token not found – aborting for safe PSMILES generation.")
|
| 846 |
|
|
@@ -878,16 +890,18 @@ def psmiles_to_rdkit_smiles(psmiles: str) -> str:
|
|
| 878 |
s = re.sub(r"\*", "[*]", s)
|
| 879 |
return s
|
| 880 |
|
| 881 |
-
|
| 882 |
_AT_BRACKET_UI_RE = re.compile(r"\[(at)\]", flags=re.IGNORECASE)
|
| 883 |
|
|
|
|
| 884 |
def replace_at_with_star(psmiles: str) -> str:
|
| 885 |
if not isinstance(psmiles, str) or not psmiles:
|
| 886 |
return psmiles
|
| 887 |
return _AT_BRACKET_UI_RE.sub("[*]", psmiles)
|
| 888 |
|
|
|
|
| 889 |
# =============================================================================
|
| 890 |
-
# SELFIES utilities
|
| 891 |
# =============================================================================
|
| 892 |
_SELFIES_TOKEN_RE = re.compile(r"\[[^\[\]]+\]")
|
| 893 |
|
|
@@ -933,32 +947,25 @@ def selfies_to_smiles(selfies_str: str) -> str:
|
|
| 933 |
def pselfies_to_psmiles(selfies_str: str) -> str:
|
| 934 |
"""
|
| 935 |
For this orchestrator we treat pSELFIES→PSMILES as SELFIES→canonical SMILES.
|
| 936 |
-
The G2 training script used a more elaborate At/[*] polymer mapping; if you
|
| 937 |
-
want that exact behaviour, you can replace this with the full pselfies_to_psmiles
|
| 938 |
-
utilities from G2.py.
|
| 939 |
"""
|
| 940 |
return selfies_to_smiles(selfies_str)
|
| 941 |
|
| 942 |
|
| 943 |
# =============================================================================
|
| 944 |
-
# SELFIES-TED decoder
|
| 945 |
# =============================================================================
|
| 946 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 947 |
SELFIES_TED_MODEL_NAME = os.environ.get("SELFIES_TED_MODEL_NAME", "ibm-research/materials.selfies-ted")
|
| 948 |
|
| 949 |
-
# Generation hyperparameters (mirroring G2 defaults)
|
| 950 |
GEN_MAX_LEN = 256
|
| 951 |
GEN_MIN_LEN = 10
|
| 952 |
GEN_TOP_P = 0.92
|
| 953 |
GEN_TEMPERATURE = 1.0
|
| 954 |
GEN_REPETITION_PENALTY = 1.05
|
| 955 |
-
LATENT_NOISE_STD_GEN = 0.15
|
| 956 |
|
| 957 |
|
| 958 |
def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
|
| 959 |
-
"""
|
| 960 |
-
Small helper to make HF loading more robust, copied from G2 spirit.
|
| 961 |
-
"""
|
| 962 |
import time
|
| 963 |
last_err = None
|
| 964 |
for t in range(max_tries):
|
|
@@ -974,7 +981,7 @@ def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
|
|
| 974 |
|
| 975 |
def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
|
| 976 |
"""
|
| 977 |
-
Load tokenizer + seq2seq model for SELFIES-TED
|
| 978 |
"""
|
| 979 |
def _load_tok():
|
| 980 |
return AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, use_fast=True)
|
|
@@ -989,8 +996,7 @@ def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
|
|
| 989 |
|
| 990 |
class CLConditionedSelfiesTEDGenerator(nn.Module):
|
| 991 |
"""
|
| 992 |
-
|
| 993 |
-
into a fixed-length memory that conditions a SELFIES-TED seq2seq model.
|
| 994 |
"""
|
| 995 |
def __init__(self, tok, seq2seq_model, cl_emb_dim: int = 600, mem_len: int = 4):
|
| 996 |
super().__init__()
|
|
@@ -1038,9 +1044,6 @@ class CLConditionedSelfiesTEDGenerator(nn.Module):
|
|
| 1038 |
temperature: float = GEN_TEMPERATURE,
|
| 1039 |
repetition_penalty: float = GEN_REPETITION_PENALTY,
|
| 1040 |
) -> List[str]:
|
| 1041 |
-
"""
|
| 1042 |
-
Latent→pSELFIES generation, as in G2.
|
| 1043 |
-
"""
|
| 1044 |
self.eval()
|
| 1045 |
z = z.to(next(self.parameters()).device)
|
| 1046 |
enc_out, attn = self.build_encoder_outputs(z)
|
|
@@ -1063,25 +1066,17 @@ class CLConditionedSelfiesTEDGenerator(nn.Module):
|
|
| 1063 |
|
| 1064 |
|
| 1065 |
# =============================================================================
|
| 1066 |
-
# Latent
|
| 1067 |
# =============================================================================
|
| 1068 |
def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 1069 |
-
"""
|
| 1070 |
-
Mirror G2's predict_latent_property(model, z):
|
| 1071 |
-
- PCA transform if present
|
| 1072 |
-
- GPR predict (scaled y)
|
| 1073 |
-
- inverse-transform via y_scaler if present
|
| 1074 |
-
"""
|
| 1075 |
z_use = np.asarray(z, dtype=np.float32)
|
| 1076 |
if z_use.ndim == 1:
|
| 1077 |
z_use = z_use.reshape(1, -1)
|
| 1078 |
|
| 1079 |
-
# Optional PCA
|
| 1080 |
pca = getattr(latent_model, "pca", None)
|
| 1081 |
if pca is not None:
|
| 1082 |
z_use = pca.transform(z_use.astype(np.float32))
|
| 1083 |
|
| 1084 |
-
# GPR or wrapped predictor
|
| 1085 |
gpr = getattr(latent_model, "gpr", None)
|
| 1086 |
if gpr is not None and hasattr(gpr, "predict"):
|
| 1087 |
y_s = gpr.predict(z_use)
|
|
@@ -1092,7 +1087,6 @@ def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarr
|
|
| 1092 |
|
| 1093 |
y_s = np.array(y_s, dtype=np.float32).reshape(-1)
|
| 1094 |
|
| 1095 |
-
# Optional scaler to get back to original units
|
| 1096 |
y_scaler = getattr(latent_model, "y_scaler", None)
|
| 1097 |
if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
|
| 1098 |
y_u = y_scaler.inverse_transform(y_s.reshape(-1, 1)).reshape(-1)
|
|
@@ -1103,7 +1097,7 @@ def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarr
|
|
| 1103 |
|
| 1104 |
|
| 1105 |
# =============================================================================
|
| 1106 |
-
# Legacy models
|
| 1107 |
# =============================================================================
|
| 1108 |
class TransformerDecoderOnly(nn.Module):
|
| 1109 |
def __init__(
|
|
@@ -1197,21 +1191,23 @@ class InverseDesignDecoder(nn.Module):
|
|
| 1197 |
|
| 1198 |
|
| 1199 |
# =============================================================================
|
| 1200 |
-
#
|
| 1201 |
# =============================================================================
|
| 1202 |
class OrchestratorConfig:
|
| 1203 |
-
def __init__(self):
|
|
|
|
|
|
|
| 1204 |
self.base_dir = "."
|
| 1205 |
-
self.cl_weights_path =
|
| 1206 |
-
self.chroma_db_path =
|
| 1207 |
self.rag_embedding_model = "text-embedding-3-small"
|
| 1208 |
|
| 1209 |
self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
|
| 1210 |
self.model = os.getenv("OPENAI_MODEL", "gpt-4.1")
|
| 1211 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1212 |
|
| 1213 |
-
self.spm_model_path =
|
| 1214 |
-
self.spm_vocab_path =
|
| 1215 |
|
| 1216 |
self.springer_api_key = os.getenv("SPRINGER_NATURE_API_KEY", "")
|
| 1217 |
self.semantic_scholar_key = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
|
|
@@ -1223,7 +1219,7 @@ class OrchestratorConfig:
|
|
| 1223 |
"property_prediction": True,
|
| 1224 |
"polymer_generation": True,
|
| 1225 |
"web_search": True,
|
| 1226 |
-
"report_generation": True, #
|
| 1227 |
"mol_render": True,
|
| 1228 |
"gen_grid": True,
|
| 1229 |
"prop_attribution": True,
|
|
@@ -1268,7 +1264,7 @@ TOOL_DESCRIPTIONS = {
|
|
| 1268 |
"CrossRef, OpenAlex, EuropePMC, arXiv, Semantic Scholar, Springer Nature (API key), Internet Archive"
|
| 1269 |
),
|
| 1270 |
},
|
| 1271 |
-
"report_generation": {
|
| 1272 |
"name": "Report Generation",
|
| 1273 |
"description": (
|
| 1274 |
"Synthesizes available tool outputs into a single structured report object "
|
|
@@ -1299,6 +1295,10 @@ TOOL_DESCRIPTIONS = {
|
|
| 1299 |
class PolymerOrchestrator:
|
| 1300 |
def __init__(self, config: OrchestratorConfig):
|
| 1301 |
self.config = config
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1302 |
self._openai_client = None
|
| 1303 |
self._openai_unavailable_reason = None
|
| 1304 |
self._data_extractor = None
|
|
@@ -1315,6 +1315,9 @@ class PolymerOrchestrator:
|
|
| 1315 |
|
| 1316 |
self.system_prompt = self._build_system_prompt()
|
| 1317 |
|
|
|
|
|
|
|
|
|
|
| 1318 |
@property
|
| 1319 |
def openai_client(self):
|
| 1320 |
if self._openai_client is None:
|
|
@@ -1336,7 +1339,7 @@ class PolymerOrchestrator:
|
|
| 1336 |
return (
|
| 1337 |
"You are the tool-planning module for **PolyAgent**, a polymer-science agent.\n"
|
| 1338 |
"Your job is to inspect the user's questions and decide which tools\n"
|
| 1339 |
-
"to run in which order.
|
| 1340 |
"Critical tool dependencies:\n"
|
| 1341 |
"- property_prediction should run AFTER cl_encoding when possible and should reuse cl_encoding.embedding.\n"
|
| 1342 |
"- polymer_generation is inverse-design and REQUIRES target_value (property -> PSMILES).\n\n"
|
|
@@ -1345,7 +1348,7 @@ class PolymerOrchestrator:
|
|
| 1345 |
)
|
| 1346 |
|
| 1347 |
# =============================================================================
|
| 1348 |
-
# Planner: LLM tool-calling
|
| 1349 |
# =============================================================================
|
| 1350 |
def analyze_query(self, user_query: str) -> Dict[str, Any]:
|
| 1351 |
schema_keys = ["analysis", "tools_required", "execution_plan"]
|
|
@@ -1394,7 +1397,6 @@ class PolymerOrchestrator:
|
|
| 1394 |
}
|
| 1395 |
}
|
| 1396 |
|
| 1397 |
-
# Preferred: function/tool-calling
|
| 1398 |
try:
|
| 1399 |
response = self.openai_client.chat.completions.create(
|
| 1400 |
model=self.config.model,
|
|
@@ -1420,7 +1422,6 @@ class PolymerOrchestrator:
|
|
| 1420 |
|
| 1421 |
raise RuntimeError("Tool-calling plan not returned; falling back to JSON mode.")
|
| 1422 |
except Exception:
|
| 1423 |
-
# Safe fallback: JSON response_format (still LLM-generated, not rule-based)
|
| 1424 |
try:
|
| 1425 |
response = self.openai_client.chat.completions.create(
|
| 1426 |
model=self.config.model,
|
|
@@ -1466,7 +1467,7 @@ class PolymerOrchestrator:
|
|
| 1466 |
output = self._run_polymer_generation(step, intermediate_data)
|
| 1467 |
elif tool_name == "web_search":
|
| 1468 |
output = self._run_web_search(step, intermediate_data)
|
| 1469 |
-
elif tool_name == "report_generation":
|
| 1470 |
output = self._run_report_generation(step, intermediate_data)
|
| 1471 |
elif tool_name == "mol_render":
|
| 1472 |
output = self._run_mol_render(step, intermediate_data)
|
|
@@ -1580,7 +1581,6 @@ class PolymerOrchestrator:
|
|
| 1580 |
"year": meta.get("year", ""),
|
| 1581 |
"source": meta.get("source", meta.get("source_path", "")),
|
| 1582 |
"venue": meta.get("venue", meta.get("journal", "")),
|
| 1583 |
-
# NEW: preserve citable identifiers when present in metadata
|
| 1584 |
"url": meta.get("url") or meta.get("link") or meta.get("href") or "",
|
| 1585 |
"doi": meta.get("doi") or "",
|
| 1586 |
})
|
|
@@ -1705,9 +1705,10 @@ class PolymerOrchestrator:
|
|
| 1705 |
"attention_mask": torch.ones(1, 2048, dtype=torch.bool, device=self.config.device),
|
| 1706 |
}
|
| 1707 |
|
| 1708 |
-
# psmiles
|
| 1709 |
if self._psmiles_tokenizer is None:
|
| 1710 |
try:
|
|
|
|
| 1711 |
self._psmiles_tokenizer = build_psmiles_tokenizer()
|
| 1712 |
except Exception:
|
| 1713 |
self._psmiles_tokenizer = None
|
|
@@ -1734,7 +1735,6 @@ class PolymerOrchestrator:
|
|
| 1734 |
with torch.no_grad():
|
| 1735 |
embeddings_dict = self._cl_encoder.encode(batch_mods)
|
| 1736 |
|
| 1737 |
-
# enforce that all four modalities are present (gine, schnet, fp, psmiles)
|
| 1738 |
required_modalities = ("gine", "schnet", "fp", "psmiles")
|
| 1739 |
missing = [m for m in required_modalities if m not in embeddings_dict]
|
| 1740 |
if missing:
|
|
@@ -1757,8 +1757,8 @@ class PolymerOrchestrator:
|
|
| 1757 |
import torch.nn as nn
|
| 1758 |
|
| 1759 |
property_name = canonical_property_name(property_name)
|
| 1760 |
-
prop_ckpt = PROPERTY_HEAD_PATHS.get(property_name)
|
| 1761 |
-
prop_meta = PROPERTY_HEAD_META.get(property_name)
|
| 1762 |
|
| 1763 |
if prop_ckpt is None:
|
| 1764 |
raise ValueError(f"No property head registered for: {property_name}")
|
|
@@ -1778,7 +1778,6 @@ class PolymerOrchestrator:
|
|
| 1778 |
|
| 1779 |
ckpt = torch.load(prop_ckpt, map_location=self.config.device, weights_only=False)
|
| 1780 |
|
| 1781 |
-
# locate state dict
|
| 1782 |
state_dict = None
|
| 1783 |
for k in ("state_dict", "model_state_dict", "model_state", "head_state_dict", "regressor_state_dict"):
|
| 1784 |
if isinstance(ckpt, dict) and k in ckpt and isinstance(ckpt[k], dict):
|
|
@@ -1804,7 +1803,6 @@ class PolymerOrchestrator:
|
|
| 1804 |
|
| 1805 |
head = RegressionHeadOnly(hidden_dim=600, dropout=float(meta.get("dropout", 0.1))).to(self.config.device)
|
| 1806 |
|
| 1807 |
-
# normalize key prefixes
|
| 1808 |
normalized = {}
|
| 1809 |
for k, v in state_dict.items():
|
| 1810 |
nk = k
|
|
@@ -1825,7 +1823,6 @@ class PolymerOrchestrator:
|
|
| 1825 |
head.load_state_dict(normalized, strict=False)
|
| 1826 |
head.eval()
|
| 1827 |
|
| 1828 |
-
# y scaler
|
| 1829 |
y_scaler = None
|
| 1830 |
if isinstance(ckpt, dict):
|
| 1831 |
for sk in ("y_scaler", "scaler_y", "target_scaler", "y_normalizer"):
|
|
@@ -1852,16 +1849,14 @@ class PolymerOrchestrator:
|
|
| 1852 |
return {"error": "Specify property name"}
|
| 1853 |
|
| 1854 |
property_name = canonical_property_name(property_name)
|
| 1855 |
-
if property_name not in PROPERTY_HEAD_PATHS:
|
| 1856 |
return {"error": f"Unsupported property: {property_name}"}
|
| 1857 |
|
| 1858 |
-
# Prefer embedding from cl_encoding output if available
|
| 1859 |
emb_from_cl = None
|
| 1860 |
cl = data.get("cl_encoding", None)
|
| 1861 |
if isinstance(cl, dict) and isinstance(cl.get("embedding"), list) and len(cl["embedding"]) == 600:
|
| 1862 |
emb_from_cl = torch.tensor([cl["embedding"]], dtype=torch.float32, device=self.config.device)
|
| 1863 |
|
| 1864 |
-
# If no embedding provided, compute via extraction + CL
|
| 1865 |
multimodal = data.get("data_extraction", None)
|
| 1866 |
psmiles = data.get("psmiles", data.get("smiles", None))
|
| 1867 |
if emb_from_cl is None:
|
|
@@ -1879,18 +1874,16 @@ class PolymerOrchestrator:
|
|
| 1879 |
with torch.no_grad():
|
| 1880 |
embs = self._cl_encoder.encode(batch_mods)
|
| 1881 |
|
| 1882 |
-
# enforce all four modalities
|
| 1883 |
required_modalities = ("gine", "schnet", "fp", "psmiles")
|
| 1884 |
missing = [m for m in required_modalities if m not in embs]
|
| 1885 |
if missing:
|
| 1886 |
return {"error": f"CL encoder did not return embeddings for modalities: {', '.join(missing)}"}
|
| 1887 |
|
| 1888 |
all_embs = [embs[k] for k in required_modalities]
|
| 1889 |
-
emb_from_cl = torch.stack(all_embs, dim=0).mean(dim=0)
|
| 1890 |
except Exception as e:
|
| 1891 |
return {"error": f"Failed to compute CL embedding: {e}"}
|
| 1892 |
|
| 1893 |
-
# Predict
|
| 1894 |
try:
|
| 1895 |
head, y_scaler, meta, ckpt_path = self._load_property_head(property_name)
|
| 1896 |
with torch.no_grad():
|
|
@@ -1898,27 +1891,21 @@ class PolymerOrchestrator:
|
|
| 1898 |
|
| 1899 |
pred_value = float(pred_norm)
|
| 1900 |
|
| 1901 |
-
# 1) Preferred: inverse_transform using the actual scaler object if available
|
| 1902 |
if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
|
| 1903 |
try:
|
| 1904 |
inv = y_scaler.inverse_transform(np.array([[pred_norm]], dtype=float))
|
| 1905 |
pred_value = float(inv[0][0])
|
| 1906 |
except Exception:
|
| 1907 |
pred_value = float(pred_norm)
|
| 1908 |
-
|
| 1909 |
-
# 2) Fallback: use metadata params if scaler object is missing
|
| 1910 |
else:
|
| 1911 |
mean = (meta or {}).get("scaler_mean", None)
|
| 1912 |
scale = (meta or {}).get("scaler_scale", None)
|
| 1913 |
-
|
| 1914 |
-
# StandardScaler inverse: x = x_scaled * scale + mean
|
| 1915 |
try:
|
| 1916 |
if isinstance(mean, list) and isinstance(scale, list) and len(mean) == 1 and len(scale) == 1:
|
| 1917 |
pred_value = float(pred_norm) * float(scale[0]) + float(mean[0])
|
| 1918 |
except Exception:
|
| 1919 |
pred_value = float(pred_norm)
|
| 1920 |
|
| 1921 |
-
# best-effort psmiles context
|
| 1922 |
out_psmiles = None
|
| 1923 |
if isinstance(multimodal, dict):
|
| 1924 |
out_psmiles = multimodal.get("canonical_psmiles")
|
|
@@ -1933,7 +1920,7 @@ class PolymerOrchestrator:
|
|
| 1933 |
"predictions": {property_name: pred_value},
|
| 1934 |
"prediction_normalized": float(pred_norm),
|
| 1935 |
"head_checkpoint_path": ckpt_path,
|
| 1936 |
-
"metadata_path": PROPERTY_HEAD_META.get(property_name, ""),
|
| 1937 |
"normalization_applied": bool(
|
| 1938 |
(y_scaler is not None and hasattr(y_scaler, "inverse_transform")) or
|
| 1939 |
((meta or {}).get("scaler_mean") is not None and (meta or {}).get("scaler_scale") is not None)
|
|
@@ -1943,11 +1930,8 @@ class PolymerOrchestrator:
|
|
| 1943 |
except Exception as e:
|
| 1944 |
return {"error": f"Property prediction failed: {e}"}
|
| 1945 |
|
| 1946 |
-
# ----------------- Inverse
|
| 1947 |
def _get_selfies_ted_backend(self, model_name: str) -> Tuple[Any, Any]:
|
| 1948 |
-
"""
|
| 1949 |
-
Cache and return (tokenizer, model) for a given SELFIES-TED model name.
|
| 1950 |
-
"""
|
| 1951 |
if not model_name:
|
| 1952 |
model_name = SELFIES_TED_MODEL_NAME
|
| 1953 |
if model_name in self._selfies_ted_cache:
|
|
@@ -1958,18 +1942,11 @@ class PolymerOrchestrator:
|
|
| 1958 |
return tok, model
|
| 1959 |
|
| 1960 |
def _load_property_generator(self, property_name: str):
|
| 1961 |
-
"""
|
| 1962 |
-
Load PolyBART-style inverse-design artifacts produced by G2.py:
|
| 1963 |
-
- decoder_best_fold*.pt : state_dict of CLConditionedSelfiesTEDGenerator
|
| 1964 |
-
- standardscaler_*.joblib : StandardScaler on property values
|
| 1965 |
-
- gpr_psmiles_*.joblib : LatentPropertyModel (z->property)
|
| 1966 |
-
- meta.json : meta info (selfies_ted_model, cl_emb_dim, mem_len, tol_scaled, ...)
|
| 1967 |
-
"""
|
| 1968 |
property_name = canonical_property_name(property_name)
|
| 1969 |
if property_name in self._property_generators:
|
| 1970 |
return self._property_generators[property_name]
|
| 1971 |
|
| 1972 |
-
base_dir = GENERATOR_DIRS.get(property_name)
|
| 1973 |
if base_dir is None:
|
| 1974 |
raise ValueError(f"No generator registered for: {property_name}")
|
| 1975 |
if not os.path.isdir(base_dir):
|
|
@@ -2017,12 +1994,10 @@ class PolymerOrchestrator:
|
|
| 2017 |
if not gpr_path or not os.path.exists(gpr_path):
|
| 2018 |
raise FileNotFoundError(f"GPR *.joblib not found in {base_dir}")
|
| 2019 |
|
| 2020 |
-
# Latent property model and scaler (G2-style LatentPropertyModel)
|
| 2021 |
_install_unpickle_shims()
|
| 2022 |
-
scaler_y = _safe_joblib_load(scaler_path)
|
| 2023 |
-
latent_prop_model = _safe_joblib_load(gpr_path)
|
| 2024 |
|
| 2025 |
-
# SELFIES-TED backbone
|
| 2026 |
selfies_ted_name = meta.get("selfies_ted_model", SELFIES_TED_MODEL_NAME)
|
| 2027 |
tok, selfies_backbone = self._get_selfies_ted_backend(selfies_ted_name)
|
| 2028 |
|
|
@@ -2037,7 +2012,6 @@ class PolymerOrchestrator:
|
|
| 2037 |
).to(self.config.device)
|
| 2038 |
|
| 2039 |
ckpt = torch.load(decoder_path, map_location=self.config.device, weights_only=False)
|
| 2040 |
-
# In G2, decoder_best_fold*.pt is a plain state_dict; keep robust fallback
|
| 2041 |
state_dict = None
|
| 2042 |
if isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
| 2043 |
state_dict = ckpt
|
|
@@ -2077,17 +2051,10 @@ class PolymerOrchestrator:
|
|
| 2077 |
latent_noise_std: float = LATENT_NOISE_STD_GEN,
|
| 2078 |
extra_factor: int = 8,
|
| 2079 |
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
|
| 2080 |
-
"""
|
| 2081 |
-
Simple PolyBART-style latent sampler:
|
| 2082 |
-
- if seed_latents provided, sample Gaussian noise around them and L2-normalize
|
| 2083 |
-
- else, sample random latents on unit hypersphere
|
| 2084 |
-
- score via latent_prop_model (z->property), keep those near target.
|
| 2085 |
-
"""
|
| 2086 |
def _l2_normalize_np(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
|
| 2087 |
n = np.linalg.norm(x, axis=-1, keepdims=True)
|
| 2088 |
return x / np.clip(n, eps, None)
|
| 2089 |
|
| 2090 |
-
# target in scaled space
|
| 2091 |
if y_scaler is not None and hasattr(y_scaler, "transform"):
|
| 2092 |
target_s = float(y_scaler.transform(np.array([[target_value]], dtype=np.float32))[0, 0])
|
| 2093 |
else:
|
|
@@ -2127,23 +2094,14 @@ class PolymerOrchestrator:
|
|
| 2127 |
|
| 2128 |
@torch.no_grad()
|
| 2129 |
def _run_polymer_generation(self, step: Dict, data: Dict) -> Dict:
|
| 2130 |
-
"""
|
| 2131 |
-
Inverse-design generation (CL latent → pSELFIES via SELFIES-TED → PSMILES).
|
| 2132 |
-
|
| 2133 |
-
Corrections implemented:
|
| 2134 |
-
1) ONLY return RDKit-valid generated outputs (filter invalid candidates).
|
| 2135 |
-
2) Replace bracketed [At]/[AT]/[aT]/... with [*] AFTER the RDKit validity check
|
| 2136 |
-
but BEFORE writing the response payload.
|
| 2137 |
-
"""
|
| 2138 |
property_name = data.get("property", data.get("property_name", None))
|
| 2139 |
if property_name is None:
|
| 2140 |
return {"error": "Specify property name for generation"}
|
| 2141 |
|
| 2142 |
property_name = canonical_property_name(property_name)
|
| 2143 |
-
if property_name not in GENERATOR_DIRS:
|
| 2144 |
return {"error": f"Unsupported property: {property_name}"}
|
| 2145 |
|
| 2146 |
-
# STRICT: require target_value (support a few common aliases)
|
| 2147 |
if data.get("target_value", None) is not None:
|
| 2148 |
target_value = data["target_value"]
|
| 2149 |
elif data.get("target", None) is not None:
|
|
@@ -2177,14 +2135,12 @@ class PolymerOrchestrator:
|
|
| 2177 |
|
| 2178 |
latent_dim = int(getattr(decoder_model, "cl_emb_dim", 600))
|
| 2179 |
|
| 2180 |
-
# choose target scaler: prefer latent_prop_model.y_scaler, fall back to scaler_y
|
| 2181 |
y_scaler = getattr(latent_prop_model, "y_scaler", None)
|
| 2182 |
if y_scaler is None:
|
| 2183 |
y_scaler = scaler_y if scaler_y is not None else None
|
| 2184 |
|
| 2185 |
tol_scaled = float(tol_scaled_override) if tol_scaled_override is not None else float(meta.get("tol_scaled", 0.5))
|
| 2186 |
|
| 2187 |
-
# Collect seed latents from available sources:
|
| 2188 |
seed_latents: List[np.ndarray] = []
|
| 2189 |
cl_enc = data.get("cl_encoding", None)
|
| 2190 |
if isinstance(cl_enc, dict) and isinstance(cl_enc.get("embedding"), list):
|
|
@@ -2192,7 +2148,6 @@ class PolymerOrchestrator:
|
|
| 2192 |
if emb.shape[0] == latent_dim:
|
| 2193 |
seed_latents.append(emb)
|
| 2194 |
|
| 2195 |
-
# Optional seed pSMILES strings for biasing
|
| 2196 |
seeds_str: List[str] = []
|
| 2197 |
if isinstance(data.get("seed_psmiles_list"), list):
|
| 2198 |
seeds_str.extend([str(x) for x in data["seed_psmiles_list"] if isinstance(x, str)])
|
|
@@ -2200,10 +2155,8 @@ class PolymerOrchestrator:
|
|
| 2200 |
seeds_str.append(str(data["seed_psmiles"]))
|
| 2201 |
if data.get("psmiles") and not seeds_str:
|
| 2202 |
seeds_str.append(str(data["psmiles"]))
|
| 2203 |
-
|
| 2204 |
seeds_str = list(dict.fromkeys(seeds_str))
|
| 2205 |
|
| 2206 |
-
# If seed strings provided but no seed latents yet, compute CL embeddings for each seed
|
| 2207 |
if seeds_str and not seed_latents:
|
| 2208 |
self._ensure_cl_encoder()
|
| 2209 |
for s in seeds_str:
|
|
@@ -2216,7 +2169,6 @@ class PolymerOrchestrator:
|
|
| 2216 |
if z.shape[0] == latent_dim:
|
| 2217 |
seed_latents.append(z)
|
| 2218 |
|
| 2219 |
-
# Sample latents targeting the property
|
| 2220 |
try:
|
| 2221 |
Z_keep, y_s_keep, y_u_keep, target_s = self._sample_latents_for_target(
|
| 2222 |
latent_prop_model=latent_prop_model,
|
|
@@ -2232,7 +2184,6 @@ class PolymerOrchestrator:
|
|
| 2232 |
except Exception as e:
|
| 2233 |
return {"error": f"Failed to sample latents conditioned on property: {e}", "paths": paths}
|
| 2234 |
|
| 2235 |
-
# --- helpers ---
|
| 2236 |
at_bracket_re = re.compile(r"\[(at)\]", flags=re.IGNORECASE)
|
| 2237 |
|
| 2238 |
def _at_to_star_bracket(s: str) -> str:
|
|
@@ -2241,7 +2192,6 @@ class PolymerOrchestrator:
|
|
| 2241 |
return at_bracket_re.sub("[*]", s)
|
| 2242 |
|
| 2243 |
def _is_rdkit_valid(psmiles: str) -> bool:
|
| 2244 |
-
# If RDKit is unavailable, we cannot validate; treat as "valid" but flag it below.
|
| 2245 |
if Chem is None:
|
| 2246 |
return True
|
| 2247 |
try:
|
|
@@ -2251,17 +2201,11 @@ class PolymerOrchestrator:
|
|
| 2251 |
except Exception:
|
| 2252 |
return False
|
| 2253 |
|
| 2254 |
-
# Decode latents → pSELFIES → PSMILES; filter to RDKit-valid ONLY.
|
| 2255 |
-
# Shortening strategy (3rd approach): generate MORE valid candidates, then keep the shortest valid K.
|
| 2256 |
requested_k = int(num_samples)
|
| 2257 |
-
|
| 2258 |
-
# candidates are tuples:
|
| 2259 |
-
# (len(psmiles), abs(y_s - target_s), psmiles_out, selfies_str, y_s, y_u)
|
| 2260 |
candidates: List[Tuple[int, float, str, str, float, float]] = []
|
| 2261 |
|
| 2262 |
-
# Reuse existing knob (extra_factor) to control "generate more" without adding new API surface.
|
| 2263 |
candidates_per_latent = max(1, int(extra_factor))
|
| 2264 |
-
max_gen_rounds = 4
|
| 2265 |
|
| 2266 |
Z_round, y_s_round, y_u_round = Z_keep, y_s_keep, y_u_keep
|
| 2267 |
for _round in range(max_gen_rounds):
|
|
@@ -2279,9 +2223,7 @@ class PolymerOrchestrator:
|
|
| 2279 |
for selfies_str in (outs or []):
|
| 2280 |
psm_raw = pselfies_to_psmiles(selfies_str)
|
| 2281 |
|
| 2282 |
-
# Correction #1: validate FIRST on the raw returned string
|
| 2283 |
if _is_rdkit_valid(psm_raw):
|
| 2284 |
-
# Correction #2: convert [At] -> [*] AFTER validation, BEFORE response writing
|
| 2285 |
psm_out = _at_to_star_bracket(psm_raw)
|
| 2286 |
candidates.append(
|
| 2287 |
(
|
|
@@ -2296,11 +2238,9 @@ class PolymerOrchestrator:
|
|
| 2296 |
except Exception:
|
| 2297 |
continue
|
| 2298 |
|
| 2299 |
-
# Stop early once we have enough valid candidates to select the shortest K.
|
| 2300 |
if len(candidates) >= requested_k:
|
| 2301 |
break
|
| 2302 |
|
| 2303 |
-
# If still short, resample latents and try again (best-effort; keeps validity constraints).
|
| 2304 |
try:
|
| 2305 |
Z_round, y_s_round, y_u_round, target_s = self._sample_latents_for_target(
|
| 2306 |
latent_prop_model=latent_prop_model,
|
|
@@ -2316,11 +2256,9 @@ class PolymerOrchestrator:
|
|
| 2316 |
except Exception:
|
| 2317 |
break
|
| 2318 |
|
| 2319 |
-
# Keep shortest valid K (tie-break by closeness to target in scaled space)
|
| 2320 |
candidates.sort(key=lambda t: (t[0], t[1]))
|
| 2321 |
selected = candidates[:requested_k]
|
| 2322 |
|
| 2323 |
-
# Ensure we return as many as requested when possible (repeat shortest valid if needed).
|
| 2324 |
if selected and len(selected) < requested_k:
|
| 2325 |
while len(selected) < requested_k:
|
| 2326 |
selected.append(selected[0])
|
|
@@ -2334,8 +2272,8 @@ class PolymerOrchestrator:
|
|
| 2334 |
"property": property_name,
|
| 2335 |
"target_value": float(target_value),
|
| 2336 |
"num_samples": int(len(generated_psmiles)),
|
| 2337 |
-
"generated_psmiles": generated_psmiles,
|
| 2338 |
-
"generated_selfies": selfies_raw,
|
| 2339 |
"latent_property_predictions": {
|
| 2340 |
"scaled": decoded_scaled,
|
| 2341 |
"unscaled": decoded_unscaled,
|
|
@@ -2386,11 +2324,11 @@ class PolymerOrchestrator:
|
|
| 2386 |
doi = normalize_doi(it.get("DOI", "")) or ""
|
| 2387 |
|
| 2388 |
publisher = (it.get("publisher") or "").lower()
|
| 2389 |
-
# Optional: exclude Brill explicitly
|
| 2390 |
if doi and doi.startswith("10.1163/"):
|
| 2391 |
continue
|
| 2392 |
if "brill" in publisher:
|
| 2393 |
continue
|
|
|
|
| 2394 |
pub_year = None
|
| 2395 |
if it.get("published-print") and isinstance(it["published-print"].get("date-parts"), list):
|
| 2396 |
pub_year = it["published-print"]["date-parts"][0][0]
|
|
@@ -2402,7 +2340,6 @@ class PolymerOrchestrator:
|
|
| 2402 |
doi = ""
|
| 2403 |
doi_url = ""
|
| 2404 |
|
| 2405 |
-
# Prefer DOI URL when valid; otherwise fall back to Crossref's URL field if present.
|
| 2406 |
landing = (it.get("URL") or "") if isinstance(it.get("URL"), str) else ""
|
| 2407 |
out.append({
|
| 2408 |
"title": title,
|
|
@@ -2433,7 +2370,6 @@ class PolymerOrchestrator:
|
|
| 2433 |
continue
|
| 2434 |
|
| 2435 |
doi = normalize_doi(it.get("doi", "")) or ""
|
| 2436 |
-
# Optional: exclude Brill explicitly
|
| 2437 |
if doi and doi.startswith("10.1163/"):
|
| 2438 |
continue
|
| 2439 |
|
|
@@ -2450,11 +2386,12 @@ class PolymerOrchestrator:
|
|
| 2450 |
|
| 2451 |
out.append({
|
| 2452 |
"title": it.get("title", ""),
|
| 2453 |
-
"doi": doi,
|
| 2454 |
-
"url": landing or "",
|
| 2455 |
"year": it.get("publication_year") or (it.get("publication_date", "")[:4]),
|
| 2456 |
"venue": (it.get("host_venue") or {}).get("display_name", ""),
|
| 2457 |
-
"type": oa_type,
|
|
|
|
| 2458 |
})
|
| 2459 |
return out
|
| 2460 |
except Exception as e:
|
|
@@ -2648,25 +2585,16 @@ class PolymerOrchestrator:
|
|
| 2648 |
return {"error": f"Unsupported web_search source: {src}"}
|
| 2649 |
|
| 2650 |
# =============================================================================
|
| 2651 |
-
# REPORT GENERATION
|
| 2652 |
# =============================================================================
|
| 2653 |
def generate_report(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 2654 |
-
"""
|
| 2655 |
-
Minimal, interface-safe report generator used by the Gradio UI fallback.
|
| 2656 |
-
|
| 2657 |
-
- Runs data_extraction -> cl_encoding -> property_prediction when possible
|
| 2658 |
-
- Optionally runs polymer_generation if generate=True or if target_value present
|
| 2659 |
-
- Optionally runs web_search if a query/literature_query is present
|
| 2660 |
-
"""
|
| 2661 |
payload = dict(data or {})
|
| 2662 |
summary: Dict[str, Any] = {}
|
| 2663 |
|
| 2664 |
-
# Seed psmiles/property
|
| 2665 |
prop = payload.get("property") or payload.get("property_name")
|
| 2666 |
if prop:
|
| 2667 |
payload["property"] = prop
|
| 2668 |
|
| 2669 |
-
# NEW: infer property from questions when missing
|
| 2670 |
if not payload.get("property"):
|
| 2671 |
qtxt = payload.get("questions") or payload.get("question") or ""
|
| 2672 |
inferred_prop = infer_property_from_text(qtxt)
|
|
@@ -2677,39 +2605,33 @@ class PolymerOrchestrator:
|
|
| 2677 |
if psmiles:
|
| 2678 |
payload["psmiles"] = psmiles
|
| 2679 |
|
| 2680 |
-
# NEW: infer target_value from questions when missing (only useful for generation)
|
| 2681 |
if payload.get("target_value", None) is None:
|
| 2682 |
qtxt = payload.get("questions") or payload.get("question") or ""
|
| 2683 |
inferred_tgt = infer_target_value_from_text(qtxt, payload.get("property"))
|
| 2684 |
if inferred_tgt is not None:
|
| 2685 |
payload["target_value"] = float(inferred_tgt)
|
| 2686 |
|
| 2687 |
-
# 1) data_extraction
|
| 2688 |
if psmiles and "data_extraction" not in payload:
|
| 2689 |
ex = self._run_data_extraction({"step": -1}, payload)
|
| 2690 |
payload["data_extraction"] = ex
|
| 2691 |
summary["data_extraction"] = ex
|
| 2692 |
|
| 2693 |
-
# 2) cl_encoding
|
| 2694 |
if "data_extraction" in payload and "cl_encoding" not in payload:
|
| 2695 |
cl = self._run_cl_encoding({"step": -1}, payload)
|
| 2696 |
payload["cl_encoding"] = cl
|
| 2697 |
summary["cl_encoding"] = cl
|
| 2698 |
|
| 2699 |
-
# 3) property_prediction
|
| 2700 |
if payload.get("property") and "property_prediction" not in payload:
|
| 2701 |
pp = self._run_property_prediction({"step": -1}, payload)
|
| 2702 |
payload["property_prediction"] = pp
|
| 2703 |
summary["property_prediction"] = pp
|
| 2704 |
|
| 2705 |
-
# 4) polymer_generation (optional)
|
| 2706 |
do_gen = bool(payload.get("generate", False)) or (payload.get("target_value", None) is not None)
|
| 2707 |
if do_gen and payload.get("property") and payload.get("target_value", None) is not None:
|
| 2708 |
gen = self._run_polymer_generation({"step": -1}, payload)
|
| 2709 |
payload["polymer_generation"] = gen
|
| 2710 |
summary["generation"] = gen
|
| 2711 |
|
| 2712 |
-
# 5) web_search (optional)
|
| 2713 |
q = payload.get("query") or payload.get("literature_query")
|
| 2714 |
src = payload.get("source") or "all"
|
| 2715 |
if q:
|
|
@@ -2730,7 +2652,6 @@ class PolymerOrchestrator:
|
|
| 2730 |
"questions": payload.get("questions") or payload.get("question") or "",
|
| 2731 |
}
|
| 2732 |
|
| 2733 |
-
# Add domain tags + (domain.com) cite tags, and tool tags [T#]
|
| 2734 |
report = _attach_source_domains(report)
|
| 2735 |
report = _index_citable_sources(report)
|
| 2736 |
report = _assign_tool_tags_to_report(report)
|
|
@@ -2740,32 +2661,23 @@ class PolymerOrchestrator:
|
|
| 2740 |
def _run_report_generation(self, step: Dict, data: Dict) -> Dict[str, Any]:
|
| 2741 |
return self.generate_report(data)
|
| 2742 |
|
|
|
|
|
|
|
|
|
|
| 2743 |
def compose_gpt_style_answer(
|
| 2744 |
self,
|
| 2745 |
report: Dict[str, Any],
|
| 2746 |
case_brief: str = "",
|
| 2747 |
questions: str = "",
|
| 2748 |
) -> Tuple[str, List[str]]:
|
| 2749 |
-
"""
|
| 2750 |
-
Interface-safe composer. Uses OpenAI if available; otherwise returns a deterministic markdown.
|
| 2751 |
-
Must return: (final_markdown, list_of_image_paths).
|
| 2752 |
-
|
| 2753 |
-
Updated requirements:
|
| 2754 |
-
- No fixed answer template: structure must follow the user's actual questions.
|
| 2755 |
-
- Literature/web citations must be domain-style like "nature.com" (never [1], [2], ...). No parentheses.
|
| 2756 |
-
- Tool-derived facts must cite as [T] only.
|
| 2757 |
-
- Tool outputs should be available verbatim without tweaking (appended as JSON blocks).
|
| 2758 |
-
"""
|
| 2759 |
imgs: List[str] = []
|
| 2760 |
|
| 2761 |
-
# Ensure tags exist even if caller didn't run generate_report()
|
| 2762 |
if isinstance(report, dict):
|
| 2763 |
report = _attach_source_domains(report)
|
| 2764 |
report = _index_citable_sources(report)
|
| 2765 |
report = _assign_tool_tags_to_report(report)
|
| 2766 |
|
| 2767 |
if self.openai_client is None:
|
| 2768 |
-
# Deterministic fallback (no API dependency)
|
| 2769 |
md_lines = []
|
| 2770 |
if case_brief:
|
| 2771 |
md_lines.append(case_brief.strip())
|
|
@@ -2780,7 +2692,6 @@ class PolymerOrchestrator:
|
|
| 2780 |
md_lines.append(str(report))
|
| 2781 |
md_lines.append("```")
|
| 2782 |
|
| 2783 |
-
# Verbatim tool outputs (no tweaking)
|
| 2784 |
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
|
| 2785 |
if verb:
|
| 2786 |
md_lines.append("\n---\n\n## Tool outputs (verbatim)\n")
|
|
@@ -2788,7 +2699,6 @@ class PolymerOrchestrator:
|
|
| 2788 |
|
| 2789 |
return "\n".join(md_lines), imgs
|
| 2790 |
|
| 2791 |
-
# OpenAI-based synthesis
|
| 2792 |
try:
|
| 2793 |
prompt = (
|
| 2794 |
"You are PolyAgent - consider yourself as an expert in polymer science. Answer the user's questions using ONLY the provided report.\n"
|
|
@@ -2804,12 +2714,11 @@ class PolymerOrchestrator:
|
|
| 2804 |
"- NON-DUPLICATES: Do not repeat the same paper link. Each DOI/URL may appear at most once in the entire answer.\n"
|
| 2805 |
"- Each major section should include at least 1 inline literature citation when relevant.\n"
|
| 2806 |
"- Do NOT invent DOIs, URLs, titles, or sources.\n\n"
|
| 2807 |
-
"- CITATIONS AS SPECIFIED ONLY: very strictly place each citation immediately after the claim it supports; do not add a references list.\n"
|
| 2808 |
"OUTPUT RULES (STRICT):\n"
|
| 2809 |
"- If a numeric value is not present in the report, write 'not available'.\n"
|
| 2810 |
"- Preserve polymer endpoint tokens exactly as '[*]' in any pSMILES/SMILES shown.\n"
|
| 2811 |
"- To prevent markdown mangling, put any pSMILES/SMILES inside code formatting.\n"
|
| 2812 |
-
"- Do not rewrite or tweak any tool outputs; if you refer to them, reference them by tag (e.g., [
|
| 2813 |
f"CASE BRIEF:\n{case_brief}\n\n"
|
| 2814 |
f"QUESTIONS:\n{questions}\n\n"
|
| 2815 |
f"REPORT (JSON):\n{json.dumps(report, ensure_ascii=False)}\n"
|
|
@@ -2825,16 +2734,12 @@ class PolymerOrchestrator:
|
|
| 2825 |
)
|
| 2826 |
txt = resp.choices[0].message.content or ""
|
| 2827 |
|
| 2828 |
-
# Enforce distributed inline clickable paper citations (do not touch tool citations).
|
| 2829 |
-
# This corrects cases where the model under-cites or clusters citations.
|
| 2830 |
try:
|
| 2831 |
min_cites = _infer_required_citation_count(questions or "", default_n=10)
|
| 2832 |
txt = _ensure_distributed_inline_citations(txt, report, min_needed=min_cites)
|
| 2833 |
except Exception:
|
| 2834 |
pass
|
| 2835 |
|
| 2836 |
-
|
| 2837 |
-
# Enforce: DOI-URL bracket text + dedupe (each DOI/URL appears at most once)
|
| 2838 |
try:
|
| 2839 |
txt = _normalize_and_dedupe_literature_links(txt, report)
|
| 2840 |
except Exception:
|
|
@@ -2845,23 +2750,20 @@ class PolymerOrchestrator:
|
|
| 2845 |
except Exception:
|
| 2846 |
pass
|
| 2847 |
|
| 2848 |
-
# Always append verbatim tool outputs (no tweaking)
|
| 2849 |
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
|
| 2850 |
if verb:
|
| 2851 |
txt = txt.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
|
| 2852 |
|
| 2853 |
return txt, imgs
|
| 2854 |
except Exception as e:
|
| 2855 |
-
# Last-resort fallback
|
| 2856 |
md = f"OpenAI compose failed: {e}\n\n```json\n{json.dumps(report, indent=2, ensure_ascii=False)}\n```"
|
| 2857 |
-
# Still append verbatim tool outputs
|
| 2858 |
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
|
| 2859 |
if verb:
|
| 2860 |
md = md.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
|
| 2861 |
return md, imgs
|
| 2862 |
|
| 2863 |
# =============================================================================
|
| 2864 |
-
# VISUAL TOOLS
|
| 2865 |
# =============================================================================
|
| 2866 |
def _run_mol_render(self, step: Dict, data: Dict) -> Dict[str, Any]:
|
| 2867 |
out_dir = Path("viz")
|
|
@@ -2915,16 +2817,6 @@ class PolymerOrchestrator:
|
|
| 2915 |
return {"png_path": png, "n": len(mols)}
|
| 2916 |
|
| 2917 |
def _run_prop_attribution(self, step: Dict, data: Dict) -> Dict[str, Any]:
|
| 2918 |
-
"""
|
| 2919 |
-
FIXED explainability:
|
| 2920 |
-
- Leave-one-atom-out occlusion attribution:
|
| 2921 |
-
score_i = baseline_pred - pred(mask atom i -> wildcard)
|
| 2922 |
-
- Highlight ONLY meaningful atoms:
|
| 2923 |
-
* Rank by |score|
|
| 2924 |
-
* Apply relative threshold vs max |score| (default 0.25)
|
| 2925 |
-
* Cap by top-K
|
| 2926 |
-
* Ensure at least 1 atom highlighted
|
| 2927 |
-
"""
|
| 2928 |
out_dir = Path("viz")
|
| 2929 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 2930 |
|
|
@@ -2935,11 +2827,10 @@ class PolymerOrchestrator:
|
|
| 2935 |
prop = canonical_property_name(data.get("property") or data.get("property_name") or "glass transition")
|
| 2936 |
top_k = int(data.get("top_k_atoms", data.get("top_k", 12)))
|
| 2937 |
|
| 2938 |
-
# importance threshold controls
|
| 2939 |
min_rel_importance = float(data.get("min_rel_importance", 0.25))
|
| 2940 |
min_abs_importance = float(data.get("min_abs_importance", 0.0))
|
| 2941 |
|
| 2942 |
-
if prop not in PROPERTY_HEAD_PATHS:
|
| 2943 |
return {"error": f"Unsupported property for attribution: {prop}"}
|
| 2944 |
if not p:
|
| 2945 |
return {"error": "no psmiles"}
|
|
@@ -2960,7 +2851,6 @@ class PolymerOrchestrator:
|
|
| 2960 |
if not isinstance(baseline, (float, int)):
|
| 2961 |
return {"error": "Baseline prediction not numeric"}
|
| 2962 |
|
| 2963 |
-
# Occlusion loop (O(N_atoms) property predictions)
|
| 2964 |
scores: Dict[int, float] = {}
|
| 2965 |
for idx in range(num_atoms):
|
| 2966 |
try:
|
|
@@ -2968,7 +2858,7 @@ class PolymerOrchestrator:
|
|
| 2968 |
tmp.GetAtomWithIdx(idx).SetAtomicNum(0) # wildcard
|
| 2969 |
mutated = tmp.GetMol()
|
| 2970 |
mut_smiles = Chem.MolToSmiles(mutated)
|
| 2971 |
-
mut_psmiles = normalize_generated_psmiles_out(mut_smiles)
|
| 2972 |
except Exception:
|
| 2973 |
scores[idx] = 0.0
|
| 2974 |
continue
|
|
@@ -2980,7 +2870,6 @@ class PolymerOrchestrator:
|
|
| 2980 |
else:
|
| 2981 |
scores[idx] = float(baseline) - float(mut_val)
|
| 2982 |
|
| 2983 |
-
# Select atoms: top-K by |score| but also require significance
|
| 2984 |
max_abs = max((abs(v) for v in scores.values()), default=0.0)
|
| 2985 |
rel_thresh = (min_rel_importance * max_abs) if max_abs > 0 else 0.0
|
| 2986 |
thresh = max(float(min_abs_importance), float(rel_thresh))
|
|
@@ -2991,11 +2880,9 @@ class PolymerOrchestrator:
|
|
| 2991 |
selected = [i for i, v in ranked if abs(v) >= thresh]
|
| 2992 |
selected = selected[:k_cap]
|
| 2993 |
|
| 2994 |
-
# Ensure at least one highlighted atom
|
| 2995 |
if not selected and ranked:
|
| 2996 |
selected = [ranked[0][0]]
|
| 2997 |
|
| 2998 |
-
# Map colors (coolwarm) over selected only
|
| 2999 |
atom_colors: Dict[int, tuple] = {}
|
| 3000 |
sel_scores = np.array([scores[i] for i in selected], dtype=float)
|
| 3001 |
if cm is not None and sel_scores.size > 0:
|
|
@@ -3043,7 +2930,6 @@ class PolymerOrchestrator:
|
|
| 3043 |
except Exception as e:
|
| 3044 |
return {"error": f"prop_attribution rendering failed: {e}"}
|
| 3045 |
|
| 3046 |
-
# convenience
|
| 3047 |
def process_query(self, user_query: str, user_inputs: Dict[str, Any] = None) -> Dict[str, Any]:
|
| 3048 |
plan = self.analyze_query(user_query)
|
| 3049 |
results = self.execute_plan(plan, user_inputs)
|
|
@@ -3051,6 +2937,6 @@ class PolymerOrchestrator:
|
|
| 3051 |
|
| 3052 |
|
| 3053 |
if __name__ == "__main__":
|
| 3054 |
-
cfg = OrchestratorConfig()
|
| 3055 |
orch = PolymerOrchestrator(cfg)
|
| 3056 |
-
print("PolymerOrchestrator ready (5M heads + 5M inverse-design + LLM
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PolyAgent Orchestrator (5M)
|
| 3 |
+
===========================
|
| 4 |
+
|
| 5 |
+
This file provides a modular orchestrator that:
|
| 6 |
+
- extracts polymer multimodal data (graph/geometry/fingerprints/PSMILES)
|
| 7 |
+
- encodes CL embeddings using PolyFusion encoders
|
| 8 |
+
- predicts single properties using best downstream heads
|
| 9 |
+
- performs inverse design using a CL-conditioned SELFIES-TED generator
|
| 10 |
+
- retrieves literature via local RAG + web APIs
|
| 11 |
+
- visualizes polymer renderings and explainability maps
|
| 12 |
+
- composes a final response along with verbatim tool outputs
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
import os
|
| 16 |
import re
|
| 17 |
import json
|
|
|
|
| 20 |
from pathlib import Path
|
| 21 |
from typing import Dict, Any, List, Optional, Tuple
|
| 22 |
from urllib.parse import urlparse
|
| 23 |
+
|
| 24 |
import numpy as np
|
| 25 |
import torch
|
| 26 |
import torch.nn as nn
|
|
|
|
| 61 |
except Exception:
|
| 62 |
spm = None
|
| 63 |
|
| 64 |
+
# Optional: selfies (for SELFIES→SMILES/PSMILES conversion)
|
| 65 |
try:
|
| 66 |
import selfies as sf
|
| 67 |
except Exception:
|
|
|
|
| 71 |
SELFIES_AVAILABLE = sf is not None
|
| 72 |
|
| 73 |
|
| 74 |
+
# =============================================================================
|
| 75 |
+
# PATHS / CONFIGURATION
|
| 76 |
+
# =============================================================================
|
| 77 |
+
class PathsConfig:
    """
    Single place for every filesystem path placeholder the orchestrator needs.

    Replace each value with the real location on your machine before running.
    """

    # Contrastive-learning (PolyFusion) encoder weights.
    cl_weights_path = "/path/to/multimodal_output_5M/best/pytorch_model.bin"

    # Persist directory of the local Chroma vectorstore used for RAG retrieval.
    chroma_db_path = "/path/to/chroma_polymer_db_big"

    # SentencePiece tokenizer model and vocabulary files.
    spm_model_path = "/path/to/spm_5M.model"
    spm_vocab_path = "/path/to/spm_5M.vocab"

    # Best-weights directory produced by the 5M downstream training script.
    downstream_bestweights_5m_dir = "/path/to/multimodal_downstream_bestweights_5M"

    # Artifacts directory produced by the 5M inverse-design generator run.
    inverse_design_5m_dir = "/path/to/multimodal_inverse_design_output_5M_polybart_style/best_models"
|
| 96 |
+
|
| 97 |
+
|
| 98 |
# =============================================================================
|
| 99 |
# DOI NORMALIZATION / RESOLUTION HELPERS
|
| 100 |
# =============================================================================
|
| 101 |
_DOI_RE = re.compile(r"^10\.\d{4,9}/\S+$", re.IGNORECASE)
|
| 102 |
|
| 103 |
+
|
| 104 |
def normalize_doi(raw: str) -> Optional[str]:
|
| 105 |
if not isinstance(raw, str):
|
| 106 |
return None
|
|
|
|
| 114 |
s = s.rstrip(").,;]}")
|
| 115 |
return s if _DOI_RE.match(s) else None
|
| 116 |
|
| 117 |
+
|
| 118 |
def doi_to_url(doi: str) -> str:
    """Build the canonical resolver URL for an already-normalized DOI."""
    # Caller is responsible for normalization (see normalize_doi); no validation here.
    return "https://doi.org/" + doi
|
| 121 |
|
| 122 |
+
|
| 123 |
def doi_resolves(doi_url: str, timeout: float = 6.0) -> bool:
|
| 124 |
"""
|
| 125 |
Best-effort resolver check. Keeps pipeline robust against dead/unregistered DOIs.
|
|
|
|
| 136 |
except Exception:
|
| 137 |
return False
|
| 138 |
|
| 139 |
+
|
| 140 |
# =============================================================================
|
| 141 |
+
# CITATION / DOMAIN TAGGING HELPERS
|
| 142 |
# =============================================================================
|
| 143 |
def _url_to_domain(url: str) -> Optional[str]:
|
| 144 |
if not isinstance(url, str) or not url.strip():
|
|
|
|
| 175 |
except Exception:
|
| 176 |
return None
|
| 177 |
|
| 178 |
+
|
| 179 |
def _attach_source_domains(obj: Any) -> Any:
|
| 180 |
"""
|
| 181 |
Recursively add a short source_domain field where URLs are present.
|
| 182 |
+
This enables domain-style citations like "(nature.com)" (note: the composer
|
| 183 |
+
later enforces DOI-URL bracket citations for papers).
|
| 184 |
"""
|
| 185 |
if isinstance(obj, list):
|
| 186 |
return [_attach_source_domains(x) for x in obj]
|
|
|
|
| 204 |
def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
|
| 205 |
"""
|
| 206 |
Add 'cite_tag' fields for citable web/RAG items using DOI-first URL tags.
|
| 207 |
+
|
| 208 |
Requirement:
|
| 209 |
- Paper citations must use the COMPLETE DOI URL (https://doi.org/...) as the bracket text.
|
| 210 |
- If DOI is not available, fall back to the best http(s) URL.
|
|
|
|
| 226 |
return False
|
| 227 |
|
| 228 |
def get_best_url(d: Dict[str, Any]) -> Optional[str]:
|
| 229 |
+
# DOI-first
|
| 230 |
doi = normalize_doi(d.get("doi", ""))
|
| 231 |
if doi:
|
| 232 |
return doi_to_url(doi)
|
|
|
|
| 244 |
out = {k: walk_and_tag(v) for k, v in node.items()}
|
| 245 |
|
| 246 |
if is_citable_item(out):
|
|
|
|
| 247 |
url = get_best_url(out)
|
| 248 |
if isinstance(url, str) and url.startswith(("http://", "https://")):
|
|
|
|
| 249 |
cur = out.get("cite_tag")
|
| 250 |
if not (isinstance(cur, str) and cur.strip().startswith(("http://", "https://"))):
|
| 251 |
out["cite_tag"] = url.strip()
|
|
|
|
|
|
|
|
|
|
| 252 |
|
|
|
|
| 253 |
url = get_best_url(out)
|
| 254 |
dom = out.get("source_domain") or (_url_to_domain(url) if url else None) or "source"
|
| 255 |
citation_index["sources"].append(
|
| 256 |
{
|
|
|
|
| 257 |
"tag": out.get("cite_tag") if isinstance(out.get("cite_tag"), str) else url,
|
| 258 |
"domain": dom,
|
| 259 |
"title": out.get("title") or out.get("name") or "Untitled",
|
|
|
|
| 274 |
return report
|
| 275 |
|
| 276 |
|
| 277 |
+
# =============================================================================
|
| 278 |
+
# INLINE CITATION ENFORCERS (distributed, deduped, DOI-first)
|
| 279 |
+
# =============================================================================
|
|
|
|
| 280 |
_CITE_COUNT_PATTERNS = [
|
| 281 |
r"(?:at\s+least\s+)?(\d{1,3})\s*(?:citations|citation|papers|paper|sources|source|references|reference)\b",
|
| 282 |
r"\bcite\s+(\d{1,3})\s*(?:papers|paper|sources|source|references|reference|citations|citation)\b",
|
|
|
|
| 299 |
|
| 300 |
def _collect_citation_links_from_report(report: Dict[str, Any]) -> List[Tuple[str, str]]:
|
| 301 |
"""
|
| 302 |
+
Return unique (cite_text, url) pairs from report['citation_index']['sources'].
|
| 303 |
+
cite_text is strictly the DOI URL (preferred) or URL fallback.
|
| 304 |
"""
|
| 305 |
out: List[Tuple[str, str]] = []
|
| 306 |
seen: set = set()
|
|
|
|
| 317 |
url = s.get("url")
|
| 318 |
if not isinstance(url, str) or not url.startswith(("http://", "https://")):
|
| 319 |
continue
|
|
|
|
| 320 |
cite_text = s.get("tag") if isinstance(s.get("tag"), str) and s.get("tag").strip() else url
|
| 321 |
if not isinstance(cite_text, str) or not cite_text.strip():
|
| 322 |
cite_text = url
|
|
|
|
| 346 |
if not citations:
|
| 347 |
return md
|
| 348 |
|
|
|
|
| 349 |
existing_urls = set(re.findall(r"\[[^\]]+\]\((https?://[^)]+)\)", md))
|
| 350 |
need = max(0, int(min_needed) - len(existing_urls))
|
| 351 |
if need <= 0:
|
| 352 |
return md
|
| 353 |
|
|
|
|
| 354 |
remaining: List[Tuple[str, str]] = [(d, u) for (d, u) in citations if u not in existing_urls]
|
| 355 |
if not remaining:
|
| 356 |
return md
|
| 357 |
|
|
|
|
| 358 |
parts = re.split(r"(```[\s\S]*?```)", md)
|
| 359 |
rem_i = 0
|
| 360 |
|
|
|
|
| 364 |
if part.startswith("```") and part.endswith("```"):
|
| 365 |
continue
|
| 366 |
|
|
|
|
| 367 |
segs = re.split(r"(\n\s*\n)", part)
|
| 368 |
for si in range(0, len(segs), 2):
|
| 369 |
if rem_i >= len(remaining) or need <= 0:
|
|
|
|
| 371 |
para = segs[si]
|
| 372 |
if not isinstance(para, str) or not para.strip():
|
| 373 |
continue
|
|
|
|
| 374 |
if para.lstrip().startswith("#"):
|
| 375 |
continue
|
|
|
|
| 376 |
if re.search(r"\[[^\]]+\]\((https?://[^)]+)\)", para):
|
| 377 |
continue
|
| 378 |
+
if not re.search(
|
| 379 |
+
r"\b(reported|shown|demonstrated|study|studies|literature|evidence|review|according)\b",
|
| 380 |
+
para,
|
| 381 |
+
flags=re.IGNORECASE,
|
| 382 |
+
):
|
| 383 |
continue
|
| 384 |
|
| 385 |
cite_text, url = remaining[rem_i]
|
|
|
|
| 386 |
segs[si] = para.rstrip() + f" [{cite_text}]({url})"
|
| 387 |
rem_i += 1
|
| 388 |
need -= 1
|
| 389 |
|
| 390 |
parts[pi] = "".join(segs)
|
| 391 |
|
|
|
|
| 392 |
if need > 0 and rem_i < len(remaining):
|
| 393 |
md2 = "".join(parts)
|
| 394 |
parts2 = re.split(r"(```[\s\S]*?```)", md2)
|
|
|
|
| 420 |
|
| 421 |
def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> str:
|
| 422 |
"""
|
| 423 |
+
Enforce:
|
| 424 |
+
- Link text must be COMPLETE DOI URL (preferred) or URL fallback.
|
| 425 |
+
- Each DOI/URL appears at most once in the answer.
|
| 426 |
Only operates outside fenced code blocks.
|
| 427 |
"""
|
| 428 |
if not isinstance(md, str) or not md.strip():
|
|
|
|
| 430 |
if not isinstance(report, dict):
|
| 431 |
return md
|
| 432 |
|
|
|
|
| 433 |
url_to_text: Dict[str, str] = {}
|
| 434 |
ci = report.get("citation_index", {})
|
| 435 |
sources = ci.get("sources") if isinstance(ci, dict) else None
|
|
|
|
| 449 |
|
| 450 |
def _rewrite_and_dedupe(text: str) -> str:
|
| 451 |
def repl(m: re.Match) -> str:
|
|
|
|
| 452 |
url = m.group(2).strip()
|
| 453 |
if url in seen_urls:
|
|
|
|
| 454 |
return ""
|
| 455 |
seen_urls.add(url)
|
| 456 |
pref = url_to_text.get(url, url)
|
| 457 |
return f"[{pref}]({url})"
|
| 458 |
+
|
| 459 |
return re.sub(r"\[([^\]]+)\]\((https?://[^)]+)\)", repl, text)
|
| 460 |
|
| 461 |
for i, part in enumerate(parts):
|
| 462 |
if part.startswith("```") and part.endswith("```"):
|
| 463 |
continue
|
| 464 |
parts[i] = _rewrite_and_dedupe(part)
|
|
|
|
| 465 |
parts[i] = re.sub(r"[ \t]{2,}", " ", parts[i])
|
| 466 |
parts[i] = re.sub(r"\n{3,}", "\n\n", parts[i])
|
| 467 |
|
|
|
|
| 471 |
def autolink_doi_urls(md: str) -> str:
|
| 472 |
"""
|
| 473 |
Wrap bare DOI URLs in Markdown links outside code blocks.
|
|
|
|
| 474 |
"""
|
| 475 |
if not md:
|
| 476 |
return md
|
|
|
|
| 481 |
parts[i] = re.sub(
|
| 482 |
r"(?<!\]\()(?P<u>https?://doi\.org/10\.\d{4,9}/[^\s\)\],;]+)",
|
| 483 |
lambda m: f"[{m.group('u')}]({m.group('u')})",
|
| 484 |
+
part,
|
| 485 |
flags=re.IGNORECASE,
|
| 486 |
)
|
| 487 |
return "".join(parts)
|
| 488 |
|
| 489 |
+
|
| 490 |
+
# =============================================================================
|
| 491 |
+
# TOOL TAGS + VERBATIM TOOL OUTPUT RENDERER
|
| 492 |
+
# =============================================================================
|
| 493 |
def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
|
| 494 |
"""
|
| 495 |
+
Ensure each tool output has a [T] cite tag.
|
|
|
|
| 496 |
"""
|
| 497 |
if not isinstance(report, dict):
|
| 498 |
return report
|
|
|
|
| 501 |
if not isinstance(tool_outputs, dict):
|
| 502 |
return report
|
| 503 |
|
| 504 |
+
preferred = [
|
|
|
|
| 505 |
"data_extraction",
|
| 506 |
"cl_encoding",
|
| 507 |
"property_prediction",
|
|
|
|
| 511 |
"report_generation",
|
| 512 |
]
|
| 513 |
|
|
|
|
| 514 |
tool_tag_map: Dict[str, str] = {}
|
| 515 |
tag = "[T]"
|
| 516 |
|
| 517 |
+
for tool in preferred:
|
|
|
|
| 518 |
node = tool_outputs.get(tool)
|
| 519 |
if node is None:
|
| 520 |
continue
|
|
|
|
| 522 |
if isinstance(node, dict) and not node.get("cite_tag"):
|
| 523 |
node["cite_tag"] = tag
|
| 524 |
|
|
|
|
| 525 |
for tool, node in tool_outputs.items():
|
| 526 |
if tool in tool_tag_map or node is None:
|
| 527 |
continue
|
|
|
|
| 529 |
if isinstance(node, dict) and not node.get("cite_tag"):
|
| 530 |
node["cite_tag"] = tag
|
| 531 |
|
|
|
|
| 532 |
try:
|
| 533 |
summary = report.get("summary", {}) or {}
|
| 534 |
if isinstance(summary, dict):
|
|
|
|
| 535 |
key_to_tool = {
|
| 536 |
"data_extraction": "data_extraction",
|
| 537 |
"cl_encoding": "cl_encoding",
|
|
|
|
| 555 |
|
| 556 |
def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
|
| 557 |
"""
|
| 558 |
+
Render tool outputs as verbatim JSON blocks (no content rewriting).
|
| 559 |
"""
|
| 560 |
if not isinstance(report, dict):
|
| 561 |
return ""
|
|
|
|
| 564 |
if not isinstance(tool_outputs, dict):
|
| 565 |
return ""
|
| 566 |
|
|
|
|
| 567 |
preferred = [
|
| 568 |
"data_extraction",
|
| 569 |
"cl_encoding",
|
|
|
|
| 591 |
|
| 592 |
|
| 593 |
# =============================================================================
|
| 594 |
+
# PICKLE / JOBLIB COMPATIBILITY SHIMS
|
| 595 |
# =============================================================================
|
| 596 |
class LatentPropertyModel:
|
| 597 |
"""
|
| 598 |
Compatibility shim for joblib/pickle artifacts saved with references like:
|
| 599 |
__main__.LatentPropertyModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 600 |
"""
|
| 601 |
def predict(self, X):
|
|
|
|
| 602 |
for attr in ("model", "gpr", "gpr_model", "estimator", "predictor", "_model", "_gpr"):
|
| 603 |
if hasattr(self, attr):
|
| 604 |
obj = getattr(self, attr)
|
| 605 |
if hasattr(obj, "predict"):
|
| 606 |
return obj.predict(X)
|
|
|
|
| 607 |
raise AttributeError(
|
| 608 |
"LatentPropertyModel shim could not find an underlying predictor. "
|
| 609 |
"Artifact expects a wrapped model attribute with a .predict method."
|
|
|
|
| 613 |
def _install_unpickle_shims() -> None:
|
| 614 |
"""
|
| 615 |
Ensure that any classes pickled under __main__ are available at load time.
|
|
|
|
| 616 |
"""
|
| 617 |
main_mod = sys.modules.get("__main__")
|
| 618 |
if main_mod is not None and not hasattr(main_mod, "LatentPropertyModel"):
|
|
|
|
| 630 |
return joblib.load(path)
|
| 631 |
except Exception as e:
|
| 632 |
msg = str(e)
|
|
|
|
| 633 |
if "Can't get attribute 'LatentPropertyModel' on <module '__main__'" in msg:
|
| 634 |
_install_unpickle_shims()
|
| 635 |
return joblib.load(path)
|
|
|
|
| 637 |
|
| 638 |
|
| 639 |
# =============================================================================
|
| 640 |
+
# PROPERTY + GENERATOR REGISTRY
|
| 641 |
# =============================================================================
|
| 642 |
+
def build_property_registries(paths: PathsConfig):
    """
    Build the three path registries used by the orchestrator:

    - best downstream property-head checkpoints
    - their metadata JSON files
    - inverse-design generator directories

    Returns a tuple (PROPERTY_HEAD_PATHS, PROPERTY_HEAD_META, GENERATOR_DIRS),
    each keyed by the canonical property name.
    """
    downstream = paths.downstream_bestweights_5m_dir
    invgen = paths.inverse_design_5m_dir

    # Canonical property name -> on-disk folder name, shared by all three
    # registries so the mapping is declared exactly once.
    folders = {
        "density": "density",
        "glass transition": "glass_transition",
        "melting": "melting",
        "specific volume": "specific_volume",
        "thermal decomposition": "thermal_decomposition",
    }

    PROPERTY_HEAD_PATHS = {
        prop: os.path.join(downstream, sub, "best_run_checkpoint.pt")
        for prop, sub in folders.items()
    }
    PROPERTY_HEAD_META = {
        prop: os.path.join(downstream, sub, "best_run_metadata.json")
        for prop, sub in folders.items()
    }
    GENERATOR_DIRS = {prop: os.path.join(invgen, sub) for prop, sub in folders.items()}

    return PROPERTY_HEAD_PATHS, PROPERTY_HEAD_META, GENERATOR_DIRS
|
| 676 |
|
| 677 |
|
| 678 |
# =============================================================================
|
| 679 |
+
# Property name canonicalization + inference helpers
|
| 680 |
# =============================================================================
|
| 681 |
def canonical_property_name(name: str) -> str:
|
| 682 |
"""
|
| 683 |
+
Map user/tool inputs to the canonical keys used in registries.
|
| 684 |
"""
|
| 685 |
if not isinstance(name, str):
|
| 686 |
return ""
|
|
|
|
| 704 |
return aliases.get(s, s)
|
| 705 |
|
| 706 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 707 |
_NUM_RE = r"[-+]?\d+(?:\.\d+)?"
|
| 708 |
|
| 709 |
+
|
| 710 |
def infer_property_from_text(text: str) -> Optional[str]:
|
| 711 |
s = (text or "").lower()
|
|
|
|
| 712 |
m = re.search(r"\bproperty\b\s*[:=]\s*([a-zA-Z _-]+)", s)
|
| 713 |
if m:
|
| 714 |
cand = m.group(1).strip().lower()
|
|
|
|
| 735 |
return "density"
|
| 736 |
return None
|
| 737 |
|
| 738 |
+
|
| 739 |
def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[float]:
|
| 740 |
sl = (text or "").lower()
|
| 741 |
|
|
|
|
| 767 |
except Exception:
|
| 768 |
pass
|
| 769 |
|
|
|
|
| 770 |
tokens = []
|
| 771 |
if prop == "glass transition":
|
| 772 |
tokens = ["tg", "glass transition"]
|
|
|
|
| 791 |
|
| 792 |
return None
|
| 793 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
|
| 795 |
# =============================================================================
|
| 796 |
+
# Tokenizers
|
| 797 |
# =============================================================================
|
| 798 |
class SimpleCharTokenizer:
|
| 799 |
def __init__(self, vocab_chars: List[str], special_tokens=("<pad>", "<s>", "</s>", "<unk>")):
|
|
|
|
| 853 |
blocked.append(tid)
|
| 854 |
setattr(self, "_blocked_ids", blocked)
|
| 855 |
|
|
|
|
| 856 |
if self.PieceToId("*") is None:
|
| 857 |
raise RuntimeError("SentencePiece tokenizer loaded but '*' token not found – aborting for safe PSMILES generation.")
|
| 858 |
|
|
|
|
| 890 |
s = re.sub(r"\*", "[*]", s)
|
| 891 |
return s
|
| 892 |
|
| 893 |
+
|
| 894 |
# Matches the UI-friendly "[at]" placeholder (any capitalization).
_AT_BRACKET_UI_RE = re.compile(r"\[(at)\]", flags=re.IGNORECASE)


def replace_at_with_star(psmiles: str) -> str:
    """Normalize UI-entered "[at]" attachment markers to the canonical "[*]"."""
    if isinstance(psmiles, str) and psmiles:
        return _AT_BRACKET_UI_RE.sub("[*]", psmiles)
    # Non-string or empty input is passed through unchanged.
    return psmiles
|
| 901 |
|
| 902 |
+
|
| 903 |
# =============================================================================
|
| 904 |
+
# SELFIES utilities
|
| 905 |
# =============================================================================
|
| 906 |
_SELFIES_TOKEN_RE = re.compile(r"\[[^\[\]]+\]")
|
| 907 |
|
|
|
|
| 947 |
def pselfies_to_psmiles(selfies_str: str) -> str:
    """
    Convert a polymer SELFIES string to PSMILES.

    For this orchestrator we treat pSELFIES→PSMILES as SELFIES→canonical SMILES;
    the conversion itself is delegated to selfies_to_smiles().
    """
    return selfies_to_smiles(selfies_str)
|
| 952 |
|
| 953 |
|
| 954 |
# =============================================================================
|
| 955 |
+
# SELFIES-TED decoder
|
| 956 |
# =============================================================================
|
| 957 |
# Optional Hugging Face auth token, read from the environment (may be None).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# SELFIES-TED checkpoint name; overridable via environment.
SELFIES_TED_MODEL_NAME = os.environ.get("SELFIES_TED_MODEL_NAME", "ibm-research/materials.selfies-ted")

# Decoding defaults for the SELFIES-TED generator; these are used as default
# arguments of the generation path below.
GEN_MAX_LEN = 256  # maximum generated sequence length (tokens)
GEN_MIN_LEN = 10  # minimum generated sequence length (tokens)
GEN_TOP_P = 0.92  # top-p (nucleus) sampling cutoff
GEN_TEMPERATURE = 1.0  # sampling temperature (1.0 = unscaled logits)
GEN_REPETITION_PENALTY = 1.05  # repetition penalty (>1 discourages repeats)
# Std-dev of Gaussian noise applied to seed latents during generation
# (default for the latent_noise_std parameter of the generation helpers).
LATENT_NOISE_STD_GEN = 0.15
|
| 966 |
|
| 967 |
|
| 968 |
def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
|
|
|
|
|
|
|
|
|
|
| 969 |
import time
|
| 970 |
last_err = None
|
| 971 |
for t in range(max_tries):
|
|
|
|
| 981 |
|
| 982 |
def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
|
| 983 |
"""
|
| 984 |
+
Load tokenizer + seq2seq model for SELFIES-TED.
|
| 985 |
"""
|
| 986 |
def _load_tok():
|
| 987 |
return AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, use_fast=True)
|
|
|
|
| 996 |
|
| 997 |
class CLConditionedSelfiesTEDGenerator(nn.Module):
|
| 998 |
"""
|
| 999 |
+
CL embedding (latent) -> fixed-length memory -> conditions SELFIES-TED seq2seq.
|
|
|
|
| 1000 |
"""
|
| 1001 |
def __init__(self, tok, seq2seq_model, cl_emb_dim: int = 600, mem_len: int = 4):
|
| 1002 |
super().__init__()
|
|
|
|
| 1044 |
temperature: float = GEN_TEMPERATURE,
|
| 1045 |
repetition_penalty: float = GEN_REPETITION_PENALTY,
|
| 1046 |
) -> List[str]:
|
|
|
|
|
|
|
|
|
|
| 1047 |
self.eval()
|
| 1048 |
z = z.to(next(self.parameters()).device)
|
| 1049 |
enc_out, attn = self.build_encoder_outputs(z)
|
|
|
|
| 1066 |
|
| 1067 |
|
| 1068 |
# =============================================================================
|
| 1069 |
+
# Latent -> property helper
|
| 1070 |
# =============================================================================
|
| 1071 |
def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1072 |
z_use = np.asarray(z, dtype=np.float32)
|
| 1073 |
if z_use.ndim == 1:
|
| 1074 |
z_use = z_use.reshape(1, -1)
|
| 1075 |
|
|
|
|
| 1076 |
pca = getattr(latent_model, "pca", None)
|
| 1077 |
if pca is not None:
|
| 1078 |
z_use = pca.transform(z_use.astype(np.float32))
|
| 1079 |
|
|
|
|
| 1080 |
gpr = getattr(latent_model, "gpr", None)
|
| 1081 |
if gpr is not None and hasattr(gpr, "predict"):
|
| 1082 |
y_s = gpr.predict(z_use)
|
|
|
|
| 1087 |
|
| 1088 |
y_s = np.array(y_s, dtype=np.float32).reshape(-1)
|
| 1089 |
|
|
|
|
| 1090 |
y_scaler = getattr(latent_model, "y_scaler", None)
|
| 1091 |
if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
|
| 1092 |
y_u = y_scaler.inverse_transform(y_s.reshape(-1, 1)).reshape(-1)
|
|
|
|
| 1097 |
|
| 1098 |
|
| 1099 |
# =============================================================================
|
| 1100 |
+
# Legacy models
|
| 1101 |
# =============================================================================
|
| 1102 |
class TransformerDecoderOnly(nn.Module):
|
| 1103 |
def __init__(
|
|
|
|
| 1191 |
|
| 1192 |
|
| 1193 |
# =============================================================================
|
| 1194 |
+
# Orchestrator config
|
| 1195 |
# =============================================================================
|
| 1196 |
class OrchestratorConfig:
|
| 1197 |
+
def __init__(self, paths: Optional[PathsConfig] = None):
|
| 1198 |
+
self.paths = paths or PathsConfig()
|
| 1199 |
+
|
| 1200 |
self.base_dir = "."
|
| 1201 |
+
self.cl_weights_path = self.paths.cl_weights_path
|
| 1202 |
+
self.chroma_db_path = self.paths.chroma_db_path
|
| 1203 |
self.rag_embedding_model = "text-embedding-3-small"
|
| 1204 |
|
| 1205 |
self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
|
| 1206 |
self.model = os.getenv("OPENAI_MODEL", "gpt-4.1")
|
| 1207 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 1208 |
|
| 1209 |
+
self.spm_model_path = self.paths.spm_model_path
|
| 1210 |
+
self.spm_vocab_path = self.paths.spm_vocab_path
|
| 1211 |
|
| 1212 |
self.springer_api_key = os.getenv("SPRINGER_NATURE_API_KEY", "")
|
| 1213 |
self.semantic_scholar_key = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
|
|
|
|
| 1219 |
"property_prediction": True,
|
| 1220 |
"polymer_generation": True,
|
| 1221 |
"web_search": True,
|
| 1222 |
+
"report_generation": True, # required by UI
|
| 1223 |
"mol_render": True,
|
| 1224 |
"gen_grid": True,
|
| 1225 |
"prop_attribution": True,
|
|
|
|
| 1264 |
"CrossRef, OpenAlex, EuropePMC, arXiv, Semantic Scholar, Springer Nature (API key), Internet Archive"
|
| 1265 |
),
|
| 1266 |
},
|
| 1267 |
+
"report_generation": {
|
| 1268 |
"name": "Report Generation",
|
| 1269 |
"description": (
|
| 1270 |
"Synthesizes available tool outputs into a single structured report object "
|
|
|
|
| 1295 |
class PolymerOrchestrator:
|
| 1296 |
def __init__(self, config: OrchestratorConfig):
|
| 1297 |
self.config = config
|
| 1298 |
+
|
| 1299 |
+
# Build registries from placeholders (no behavior change; just centralization)
|
| 1300 |
+
self.PROPERTY_HEAD_PATHS, self.PROPERTY_HEAD_META, self.GENERATOR_DIRS = build_property_registries(self.config.paths)
|
| 1301 |
+
|
| 1302 |
self._openai_client = None
|
| 1303 |
self._openai_unavailable_reason = None
|
| 1304 |
self._data_extractor = None
|
|
|
|
| 1315 |
|
| 1316 |
self.system_prompt = self._build_system_prompt()
|
| 1317 |
|
| 1318 |
+
# -------------------------------------------------------------------------
|
| 1319 |
+
# OpenAI client
|
| 1320 |
+
# -------------------------------------------------------------------------
|
| 1321 |
@property
|
| 1322 |
def openai_client(self):
|
| 1323 |
if self._openai_client is None:
|
|
|
|
| 1339 |
return (
|
| 1340 |
"You are the tool-planning module for **PolyAgent**, a polymer-science agent.\n"
|
| 1341 |
"Your job is to inspect the user's questions and decide which tools\n"
|
| 1342 |
+
"to run in which order.\n\n"
|
| 1343 |
"Critical tool dependencies:\n"
|
| 1344 |
"- property_prediction should run AFTER cl_encoding when possible and should reuse cl_encoding.embedding.\n"
|
| 1345 |
"- polymer_generation is inverse-design and REQUIRES target_value (property -> PSMILES).\n\n"
|
|
|
|
| 1348 |
)
|
| 1349 |
|
| 1350 |
# =============================================================================
|
| 1351 |
+
# Planner: LLM tool-calling
|
| 1352 |
# =============================================================================
|
| 1353 |
def analyze_query(self, user_query: str) -> Dict[str, Any]:
|
| 1354 |
schema_keys = ["analysis", "tools_required", "execution_plan"]
|
|
|
|
| 1397 |
}
|
| 1398 |
}
|
| 1399 |
|
|
|
|
| 1400 |
try:
|
| 1401 |
response = self.openai_client.chat.completions.create(
|
| 1402 |
model=self.config.model,
|
|
|
|
| 1422 |
|
| 1423 |
raise RuntimeError("Tool-calling plan not returned; falling back to JSON mode.")
|
| 1424 |
except Exception:
|
|
|
|
| 1425 |
try:
|
| 1426 |
response = self.openai_client.chat.completions.create(
|
| 1427 |
model=self.config.model,
|
|
|
|
| 1467 |
output = self._run_polymer_generation(step, intermediate_data)
|
| 1468 |
elif tool_name == "web_search":
|
| 1469 |
output = self._run_web_search(step, intermediate_data)
|
| 1470 |
+
elif tool_name == "report_generation":
|
| 1471 |
output = self._run_report_generation(step, intermediate_data)
|
| 1472 |
elif tool_name == "mol_render":
|
| 1473 |
output = self._run_mol_render(step, intermediate_data)
|
|
|
|
| 1581 |
"year": meta.get("year", ""),
|
| 1582 |
"source": meta.get("source", meta.get("source_path", "")),
|
| 1583 |
"venue": meta.get("venue", meta.get("journal", "")),
|
|
|
|
| 1584 |
"url": meta.get("url") or meta.get("link") or meta.get("href") or "",
|
| 1585 |
"doi": meta.get("doi") or "",
|
| 1586 |
})
|
|
|
|
| 1705 |
"attention_mask": torch.ones(1, 2048, dtype=torch.bool, device=self.config.device),
|
| 1706 |
}
|
| 1707 |
|
| 1708 |
+
# psmiles encoder input
|
| 1709 |
if self._psmiles_tokenizer is None:
|
| 1710 |
try:
|
| 1711 |
+
from PolyFusion.DeBERTav2 import build_psmiles_tokenizer
|
| 1712 |
self._psmiles_tokenizer = build_psmiles_tokenizer()
|
| 1713 |
except Exception:
|
| 1714 |
self._psmiles_tokenizer = None
|
|
|
|
| 1735 |
with torch.no_grad():
|
| 1736 |
embeddings_dict = self._cl_encoder.encode(batch_mods)
|
| 1737 |
|
|
|
|
| 1738 |
required_modalities = ("gine", "schnet", "fp", "psmiles")
|
| 1739 |
missing = [m for m in required_modalities if m not in embeddings_dict]
|
| 1740 |
if missing:
|
|
|
|
| 1757 |
import torch.nn as nn
|
| 1758 |
|
| 1759 |
property_name = canonical_property_name(property_name)
|
| 1760 |
+
prop_ckpt = self.PROPERTY_HEAD_PATHS.get(property_name)
|
| 1761 |
+
prop_meta = self.PROPERTY_HEAD_META.get(property_name)
|
| 1762 |
|
| 1763 |
if prop_ckpt is None:
|
| 1764 |
raise ValueError(f"No property head registered for: {property_name}")
|
|
|
|
| 1778 |
|
| 1779 |
ckpt = torch.load(prop_ckpt, map_location=self.config.device, weights_only=False)
|
| 1780 |
|
|
|
|
| 1781 |
state_dict = None
|
| 1782 |
for k in ("state_dict", "model_state_dict", "model_state", "head_state_dict", "regressor_state_dict"):
|
| 1783 |
if isinstance(ckpt, dict) and k in ckpt and isinstance(ckpt[k], dict):
|
|
|
|
| 1803 |
|
| 1804 |
head = RegressionHeadOnly(hidden_dim=600, dropout=float(meta.get("dropout", 0.1))).to(self.config.device)
|
| 1805 |
|
|
|
|
| 1806 |
normalized = {}
|
| 1807 |
for k, v in state_dict.items():
|
| 1808 |
nk = k
|
|
|
|
| 1823 |
head.load_state_dict(normalized, strict=False)
|
| 1824 |
head.eval()
|
| 1825 |
|
|
|
|
| 1826 |
y_scaler = None
|
| 1827 |
if isinstance(ckpt, dict):
|
| 1828 |
for sk in ("y_scaler", "scaler_y", "target_scaler", "y_normalizer"):
|
|
|
|
| 1849 |
return {"error": "Specify property name"}
|
| 1850 |
|
| 1851 |
property_name = canonical_property_name(property_name)
|
| 1852 |
+
if property_name not in self.PROPERTY_HEAD_PATHS:
|
| 1853 |
return {"error": f"Unsupported property: {property_name}"}
|
| 1854 |
|
|
|
|
| 1855 |
emb_from_cl = None
|
| 1856 |
cl = data.get("cl_encoding", None)
|
| 1857 |
if isinstance(cl, dict) and isinstance(cl.get("embedding"), list) and len(cl["embedding"]) == 600:
|
| 1858 |
emb_from_cl = torch.tensor([cl["embedding"]], dtype=torch.float32, device=self.config.device)
|
| 1859 |
|
|
|
|
| 1860 |
multimodal = data.get("data_extraction", None)
|
| 1861 |
psmiles = data.get("psmiles", data.get("smiles", None))
|
| 1862 |
if emb_from_cl is None:
|
|
|
|
| 1874 |
with torch.no_grad():
|
| 1875 |
embs = self._cl_encoder.encode(batch_mods)
|
| 1876 |
|
|
|
|
| 1877 |
required_modalities = ("gine", "schnet", "fp", "psmiles")
|
| 1878 |
missing = [m for m in required_modalities if m not in embs]
|
| 1879 |
if missing:
|
| 1880 |
return {"error": f"CL encoder did not return embeddings for modalities: {', '.join(missing)}"}
|
| 1881 |
|
| 1882 |
all_embs = [embs[k] for k in required_modalities]
|
| 1883 |
+
emb_from_cl = torch.stack(all_embs, dim=0).mean(dim=0)
|
| 1884 |
except Exception as e:
|
| 1885 |
return {"error": f"Failed to compute CL embedding: {e}"}
|
| 1886 |
|
|
|
|
| 1887 |
try:
|
| 1888 |
head, y_scaler, meta, ckpt_path = self._load_property_head(property_name)
|
| 1889 |
with torch.no_grad():
|
|
|
|
| 1891 |
|
| 1892 |
pred_value = float(pred_norm)
|
| 1893 |
|
|
|
|
| 1894 |
if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
|
| 1895 |
try:
|
| 1896 |
inv = y_scaler.inverse_transform(np.array([[pred_norm]], dtype=float))
|
| 1897 |
pred_value = float(inv[0][0])
|
| 1898 |
except Exception:
|
| 1899 |
pred_value = float(pred_norm)
|
|
|
|
|
|
|
| 1900 |
else:
|
| 1901 |
mean = (meta or {}).get("scaler_mean", None)
|
| 1902 |
scale = (meta or {}).get("scaler_scale", None)
|
|
|
|
|
|
|
| 1903 |
try:
|
| 1904 |
if isinstance(mean, list) and isinstance(scale, list) and len(mean) == 1 and len(scale) == 1:
|
| 1905 |
pred_value = float(pred_norm) * float(scale[0]) + float(mean[0])
|
| 1906 |
except Exception:
|
| 1907 |
pred_value = float(pred_norm)
|
| 1908 |
|
|
|
|
| 1909 |
out_psmiles = None
|
| 1910 |
if isinstance(multimodal, dict):
|
| 1911 |
out_psmiles = multimodal.get("canonical_psmiles")
|
|
|
|
| 1920 |
"predictions": {property_name: pred_value},
|
| 1921 |
"prediction_normalized": float(pred_norm),
|
| 1922 |
"head_checkpoint_path": ckpt_path,
|
| 1923 |
+
"metadata_path": self.PROPERTY_HEAD_META.get(property_name, ""),
|
| 1924 |
"normalization_applied": bool(
|
| 1925 |
(y_scaler is not None and hasattr(y_scaler, "inverse_transform")) or
|
| 1926 |
((meta or {}).get("scaler_mean") is not None and (meta or {}).get("scaler_scale") is not None)
|
|
|
|
| 1930 |
except Exception as e:
|
| 1931 |
return {"error": f"Property prediction failed: {e}"}
|
| 1932 |
|
| 1933 |
+
# ----------------- Inverse design generator (CL + SELFIES-TED) ----------------- #
|
| 1934 |
def _get_selfies_ted_backend(self, model_name: str) -> Tuple[Any, Any]:
|
|
|
|
|
|
|
|
|
|
| 1935 |
if not model_name:
|
| 1936 |
model_name = SELFIES_TED_MODEL_NAME
|
| 1937 |
if model_name in self._selfies_ted_cache:
|
|
|
|
| 1942 |
return tok, model
|
| 1943 |
|
| 1944 |
def _load_property_generator(self, property_name: str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1945 |
property_name = canonical_property_name(property_name)
|
| 1946 |
if property_name in self._property_generators:
|
| 1947 |
return self._property_generators[property_name]
|
| 1948 |
|
| 1949 |
+
base_dir = self.GENERATOR_DIRS.get(property_name)
|
| 1950 |
if base_dir is None:
|
| 1951 |
raise ValueError(f"No generator registered for: {property_name}")
|
| 1952 |
if not os.path.isdir(base_dir):
|
|
|
|
| 1994 |
if not gpr_path or not os.path.exists(gpr_path):
|
| 1995 |
raise FileNotFoundError(f"GPR *.joblib not found in {base_dir}")
|
| 1996 |
|
|
|
|
| 1997 |
_install_unpickle_shims()
|
| 1998 |
+
scaler_y = _safe_joblib_load(scaler_path)
|
| 1999 |
+
latent_prop_model = _safe_joblib_load(gpr_path)
|
| 2000 |
|
|
|
|
| 2001 |
selfies_ted_name = meta.get("selfies_ted_model", SELFIES_TED_MODEL_NAME)
|
| 2002 |
tok, selfies_backbone = self._get_selfies_ted_backend(selfies_ted_name)
|
| 2003 |
|
|
|
|
| 2012 |
).to(self.config.device)
|
| 2013 |
|
| 2014 |
ckpt = torch.load(decoder_path, map_location=self.config.device, weights_only=False)
|
|
|
|
| 2015 |
state_dict = None
|
| 2016 |
if isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
| 2017 |
state_dict = ckpt
|
|
|
|
| 2051 |
latent_noise_std: float = LATENT_NOISE_STD_GEN,
|
| 2052 |
extra_factor: int = 8,
|
| 2053 |
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2054 |
def _l2_normalize_np(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
|
| 2055 |
n = np.linalg.norm(x, axis=-1, keepdims=True)
|
| 2056 |
return x / np.clip(n, eps, None)
|
| 2057 |
|
|
|
|
| 2058 |
if y_scaler is not None and hasattr(y_scaler, "transform"):
|
| 2059 |
target_s = float(y_scaler.transform(np.array([[target_value]], dtype=np.float32))[0, 0])
|
| 2060 |
else:
|
|
|
|
| 2094 |
|
| 2095 |
@torch.no_grad()
|
| 2096 |
def _run_polymer_generation(self, step: Dict, data: Dict) -> Dict:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2097 |
property_name = data.get("property", data.get("property_name", None))
|
| 2098 |
if property_name is None:
|
| 2099 |
return {"error": "Specify property name for generation"}
|
| 2100 |
|
| 2101 |
property_name = canonical_property_name(property_name)
|
| 2102 |
+
if property_name not in self.GENERATOR_DIRS:
|
| 2103 |
return {"error": f"Unsupported property: {property_name}"}
|
| 2104 |
|
|
|
|
| 2105 |
if data.get("target_value", None) is not None:
|
| 2106 |
target_value = data["target_value"]
|
| 2107 |
elif data.get("target", None) is not None:
|
|
|
|
| 2135 |
|
| 2136 |
latent_dim = int(getattr(decoder_model, "cl_emb_dim", 600))
|
| 2137 |
|
|
|
|
| 2138 |
y_scaler = getattr(latent_prop_model, "y_scaler", None)
|
| 2139 |
if y_scaler is None:
|
| 2140 |
y_scaler = scaler_y if scaler_y is not None else None
|
| 2141 |
|
| 2142 |
tol_scaled = float(tol_scaled_override) if tol_scaled_override is not None else float(meta.get("tol_scaled", 0.5))
|
| 2143 |
|
|
|
|
| 2144 |
seed_latents: List[np.ndarray] = []
|
| 2145 |
cl_enc = data.get("cl_encoding", None)
|
| 2146 |
if isinstance(cl_enc, dict) and isinstance(cl_enc.get("embedding"), list):
|
|
|
|
| 2148 |
if emb.shape[0] == latent_dim:
|
| 2149 |
seed_latents.append(emb)
|
| 2150 |
|
|
|
|
| 2151 |
seeds_str: List[str] = []
|
| 2152 |
if isinstance(data.get("seed_psmiles_list"), list):
|
| 2153 |
seeds_str.extend([str(x) for x in data["seed_psmiles_list"] if isinstance(x, str)])
|
|
|
|
| 2155 |
seeds_str.append(str(data["seed_psmiles"]))
|
| 2156 |
if data.get("psmiles") and not seeds_str:
|
| 2157 |
seeds_str.append(str(data["psmiles"]))
|
|
|
|
| 2158 |
seeds_str = list(dict.fromkeys(seeds_str))
|
| 2159 |
|
|
|
|
| 2160 |
if seeds_str and not seed_latents:
|
| 2161 |
self._ensure_cl_encoder()
|
| 2162 |
for s in seeds_str:
|
|
|
|
| 2169 |
if z.shape[0] == latent_dim:
|
| 2170 |
seed_latents.append(z)
|
| 2171 |
|
|
|
|
| 2172 |
try:
|
| 2173 |
Z_keep, y_s_keep, y_u_keep, target_s = self._sample_latents_for_target(
|
| 2174 |
latent_prop_model=latent_prop_model,
|
|
|
|
| 2184 |
except Exception as e:
|
| 2185 |
return {"error": f"Failed to sample latents conditioned on property: {e}", "paths": paths}
|
| 2186 |
|
|
|
|
| 2187 |
at_bracket_re = re.compile(r"\[(at)\]", flags=re.IGNORECASE)
|
| 2188 |
|
| 2189 |
def _at_to_star_bracket(s: str) -> str:
|
|
|
|
| 2192 |
return at_bracket_re.sub("[*]", s)
|
| 2193 |
|
| 2194 |
def _is_rdkit_valid(psmiles: str) -> bool:
|
|
|
|
| 2195 |
if Chem is None:
|
| 2196 |
return True
|
| 2197 |
try:
|
|
|
|
| 2201 |
except Exception:
|
| 2202 |
return False
|
| 2203 |
|
|
|
|
|
|
|
| 2204 |
requested_k = int(num_samples)
|
|
|
|
|
|
|
|
|
|
| 2205 |
candidates: List[Tuple[int, float, str, str, float, float]] = []
|
| 2206 |
|
|
|
|
| 2207 |
candidates_per_latent = max(1, int(extra_factor))
|
| 2208 |
+
max_gen_rounds = 4
|
| 2209 |
|
| 2210 |
Z_round, y_s_round, y_u_round = Z_keep, y_s_keep, y_u_keep
|
| 2211 |
for _round in range(max_gen_rounds):
|
|
|
|
| 2223 |
for selfies_str in (outs or []):
|
| 2224 |
psm_raw = pselfies_to_psmiles(selfies_str)
|
| 2225 |
|
|
|
|
| 2226 |
if _is_rdkit_valid(psm_raw):
|
|
|
|
| 2227 |
psm_out = _at_to_star_bracket(psm_raw)
|
| 2228 |
candidates.append(
|
| 2229 |
(
|
|
|
|
| 2238 |
except Exception:
|
| 2239 |
continue
|
| 2240 |
|
|
|
|
| 2241 |
if len(candidates) >= requested_k:
|
| 2242 |
break
|
| 2243 |
|
|
|
|
| 2244 |
try:
|
| 2245 |
Z_round, y_s_round, y_u_round, target_s = self._sample_latents_for_target(
|
| 2246 |
latent_prop_model=latent_prop_model,
|
|
|
|
| 2256 |
except Exception:
|
| 2257 |
break
|
| 2258 |
|
|
|
|
| 2259 |
candidates.sort(key=lambda t: (t[0], t[1]))
|
| 2260 |
selected = candidates[:requested_k]
|
| 2261 |
|
|
|
|
| 2262 |
if selected and len(selected) < requested_k:
|
| 2263 |
while len(selected) < requested_k:
|
| 2264 |
selected.append(selected[0])
|
|
|
|
| 2272 |
"property": property_name,
|
| 2273 |
"target_value": float(target_value),
|
| 2274 |
"num_samples": int(len(generated_psmiles)),
|
| 2275 |
+
"generated_psmiles": generated_psmiles,
|
| 2276 |
+
"generated_selfies": selfies_raw,
|
| 2277 |
"latent_property_predictions": {
|
| 2278 |
"scaled": decoded_scaled,
|
| 2279 |
"unscaled": decoded_unscaled,
|
|
|
|
| 2324 |
doi = normalize_doi(it.get("DOI", "")) or ""
|
| 2325 |
|
| 2326 |
publisher = (it.get("publisher") or "").lower()
|
|
|
|
| 2327 |
if doi and doi.startswith("10.1163/"):
|
| 2328 |
continue
|
| 2329 |
if "brill" in publisher:
|
| 2330 |
continue
|
| 2331 |
+
|
| 2332 |
pub_year = None
|
| 2333 |
if it.get("published-print") and isinstance(it["published-print"].get("date-parts"), list):
|
| 2334 |
pub_year = it["published-print"]["date-parts"][0][0]
|
|
|
|
| 2340 |
doi = ""
|
| 2341 |
doi_url = ""
|
| 2342 |
|
|
|
|
| 2343 |
landing = (it.get("URL") or "") if isinstance(it.get("URL"), str) else ""
|
| 2344 |
out.append({
|
| 2345 |
"title": title,
|
|
|
|
| 2370 |
continue
|
| 2371 |
|
| 2372 |
doi = normalize_doi(it.get("doi", "")) or ""
|
|
|
|
| 2373 |
if doi and doi.startswith("10.1163/"):
|
| 2374 |
continue
|
| 2375 |
|
|
|
|
| 2386 |
|
| 2387 |
out.append({
|
| 2388 |
"title": it.get("title", ""),
|
| 2389 |
+
"doi": doi,
|
| 2390 |
+
"url": landing or "",
|
| 2391 |
"year": it.get("publication_year") or (it.get("publication_date", "")[:4]),
|
| 2392 |
"venue": (it.get("host_venue") or {}).get("display_name", ""),
|
| 2393 |
+
"type": oa_type,
|
| 2394 |
+
"source": "OpenAlex",
|
| 2395 |
})
|
| 2396 |
return out
|
| 2397 |
except Exception as e:
|
|
|
|
| 2585 |
return {"error": f"Unsupported web_search source: {src}"}
|
| 2586 |
|
| 2587 |
# =============================================================================
|
| 2588 |
+
# REPORT GENERATION
|
| 2589 |
# =============================================================================
|
| 2590 |
def generate_report(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2591 |
payload = dict(data or {})
|
| 2592 |
summary: Dict[str, Any] = {}
|
| 2593 |
|
|
|
|
| 2594 |
prop = payload.get("property") or payload.get("property_name")
|
| 2595 |
if prop:
|
| 2596 |
payload["property"] = prop
|
| 2597 |
|
|
|
|
| 2598 |
if not payload.get("property"):
|
| 2599 |
qtxt = payload.get("questions") or payload.get("question") or ""
|
| 2600 |
inferred_prop = infer_property_from_text(qtxt)
|
|
|
|
| 2605 |
if psmiles:
|
| 2606 |
payload["psmiles"] = psmiles
|
| 2607 |
|
|
|
|
| 2608 |
if payload.get("target_value", None) is None:
|
| 2609 |
qtxt = payload.get("questions") or payload.get("question") or ""
|
| 2610 |
inferred_tgt = infer_target_value_from_text(qtxt, payload.get("property"))
|
| 2611 |
if inferred_tgt is not None:
|
| 2612 |
payload["target_value"] = float(inferred_tgt)
|
| 2613 |
|
|
|
|
| 2614 |
if psmiles and "data_extraction" not in payload:
|
| 2615 |
ex = self._run_data_extraction({"step": -1}, payload)
|
| 2616 |
payload["data_extraction"] = ex
|
| 2617 |
summary["data_extraction"] = ex
|
| 2618 |
|
|
|
|
| 2619 |
if "data_extraction" in payload and "cl_encoding" not in payload:
|
| 2620 |
cl = self._run_cl_encoding({"step": -1}, payload)
|
| 2621 |
payload["cl_encoding"] = cl
|
| 2622 |
summary["cl_encoding"] = cl
|
| 2623 |
|
|
|
|
| 2624 |
if payload.get("property") and "property_prediction" not in payload:
|
| 2625 |
pp = self._run_property_prediction({"step": -1}, payload)
|
| 2626 |
payload["property_prediction"] = pp
|
| 2627 |
summary["property_prediction"] = pp
|
| 2628 |
|
|
|
|
| 2629 |
do_gen = bool(payload.get("generate", False)) or (payload.get("target_value", None) is not None)
|
| 2630 |
if do_gen and payload.get("property") and payload.get("target_value", None) is not None:
|
| 2631 |
gen = self._run_polymer_generation({"step": -1}, payload)
|
| 2632 |
payload["polymer_generation"] = gen
|
| 2633 |
summary["generation"] = gen
|
| 2634 |
|
|
|
|
| 2635 |
q = payload.get("query") or payload.get("literature_query")
|
| 2636 |
src = payload.get("source") or "all"
|
| 2637 |
if q:
|
|
|
|
| 2652 |
"questions": payload.get("questions") or payload.get("question") or "",
|
| 2653 |
}
|
| 2654 |
|
|
|
|
| 2655 |
report = _attach_source_domains(report)
|
| 2656 |
report = _index_citable_sources(report)
|
| 2657 |
report = _assign_tool_tags_to_report(report)
|
|
|
|
| 2661 |
def _run_report_generation(self, step: Dict, data: Dict) -> Dict[str, Any]:
|
| 2662 |
return self.generate_report(data)
|
| 2663 |
|
| 2664 |
+
# =============================================================================
|
| 2665 |
+
# COMPOSER
|
| 2666 |
+
# =============================================================================
|
| 2667 |
def compose_gpt_style_answer(
|
| 2668 |
self,
|
| 2669 |
report: Dict[str, Any],
|
| 2670 |
case_brief: str = "",
|
| 2671 |
questions: str = "",
|
| 2672 |
) -> Tuple[str, List[str]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2673 |
imgs: List[str] = []
|
| 2674 |
|
|
|
|
| 2675 |
if isinstance(report, dict):
|
| 2676 |
report = _attach_source_domains(report)
|
| 2677 |
report = _index_citable_sources(report)
|
| 2678 |
report = _assign_tool_tags_to_report(report)
|
| 2679 |
|
| 2680 |
if self.openai_client is None:
|
|
|
|
| 2681 |
md_lines = []
|
| 2682 |
if case_brief:
|
| 2683 |
md_lines.append(case_brief.strip())
|
|
|
|
| 2692 |
md_lines.append(str(report))
|
| 2693 |
md_lines.append("```")
|
| 2694 |
|
|
|
|
| 2695 |
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
|
| 2696 |
if verb:
|
| 2697 |
md_lines.append("\n---\n\n## Tool outputs (verbatim)\n")
|
|
|
|
| 2699 |
|
| 2700 |
return "\n".join(md_lines), imgs
|
| 2701 |
|
|
|
|
| 2702 |
try:
|
| 2703 |
prompt = (
|
| 2704 |
"You are PolyAgent - consider yourself as an expert in polymer science. Answer the user's questions using ONLY the provided report.\n"
|
|
|
|
| 2714 |
"- NON-DUPLICATES: Do not repeat the same paper link. Each DOI/URL may appear at most once in the entire answer.\n"
|
| 2715 |
"- Each major section should include at least 1 inline literature citation when relevant.\n"
|
| 2716 |
"- Do NOT invent DOIs, URLs, titles, or sources.\n\n"
|
|
|
|
| 2717 |
"OUTPUT RULES (STRICT):\n"
|
| 2718 |
"- If a numeric value is not present in the report, write 'not available'.\n"
|
| 2719 |
"- Preserve polymer endpoint tokens exactly as '[*]' in any pSMILES/SMILES shown.\n"
|
| 2720 |
"- To prevent markdown mangling, put any pSMILES/SMILES inside code formatting.\n"
|
| 2721 |
+
"- Do not rewrite or tweak any tool outputs; if you refer to them, reference them by tag (e.g., [T]).\n\n"
|
| 2722 |
f"CASE BRIEF:\n{case_brief}\n\n"
|
| 2723 |
f"QUESTIONS:\n{questions}\n\n"
|
| 2724 |
f"REPORT (JSON):\n{json.dumps(report, ensure_ascii=False)}\n"
|
|
|
|
| 2734 |
)
|
| 2735 |
txt = resp.choices[0].message.content or ""
|
| 2736 |
|
|
|
|
|
|
|
| 2737 |
try:
|
| 2738 |
min_cites = _infer_required_citation_count(questions or "", default_n=10)
|
| 2739 |
txt = _ensure_distributed_inline_citations(txt, report, min_needed=min_cites)
|
| 2740 |
except Exception:
|
| 2741 |
pass
|
| 2742 |
|
|
|
|
|
|
|
| 2743 |
try:
|
| 2744 |
txt = _normalize_and_dedupe_literature_links(txt, report)
|
| 2745 |
except Exception:
|
|
|
|
| 2750 |
except Exception:
|
| 2751 |
pass
|
| 2752 |
|
|
|
|
| 2753 |
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
|
| 2754 |
if verb:
|
| 2755 |
txt = txt.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
|
| 2756 |
|
| 2757 |
return txt, imgs
|
| 2758 |
except Exception as e:
|
|
|
|
| 2759 |
md = f"OpenAI compose failed: {e}\n\n```json\n{json.dumps(report, indent=2, ensure_ascii=False)}\n```"
|
|
|
|
| 2760 |
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
|
| 2761 |
if verb:
|
| 2762 |
md = md.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
|
| 2763 |
return md, imgs
|
| 2764 |
|
| 2765 |
# =============================================================================
|
| 2766 |
+
# VISUAL TOOLS
|
| 2767 |
# =============================================================================
|
| 2768 |
def _run_mol_render(self, step: Dict, data: Dict) -> Dict[str, Any]:
|
| 2769 |
out_dir = Path("viz")
|
|
|
|
| 2817 |
return {"png_path": png, "n": len(mols)}
|
| 2818 |
|
| 2819 |
def _run_prop_attribution(self, step: Dict, data: Dict) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2820 |
out_dir = Path("viz")
|
| 2821 |
out_dir.mkdir(parents=True, exist_ok=True)
|
| 2822 |
|
|
|
|
| 2827 |
prop = canonical_property_name(data.get("property") or data.get("property_name") or "glass transition")
|
| 2828 |
top_k = int(data.get("top_k_atoms", data.get("top_k", 12)))
|
| 2829 |
|
|
|
|
| 2830 |
min_rel_importance = float(data.get("min_rel_importance", 0.25))
|
| 2831 |
min_abs_importance = float(data.get("min_abs_importance", 0.0))
|
| 2832 |
|
| 2833 |
+
if prop not in self.PROPERTY_HEAD_PATHS:
|
| 2834 |
return {"error": f"Unsupported property for attribution: {prop}"}
|
| 2835 |
if not p:
|
| 2836 |
return {"error": "no psmiles"}
|
|
|
|
| 2851 |
if not isinstance(baseline, (float, int)):
|
| 2852 |
return {"error": "Baseline prediction not numeric"}
|
| 2853 |
|
|
|
|
| 2854 |
scores: Dict[int, float] = {}
|
| 2855 |
for idx in range(num_atoms):
|
| 2856 |
try:
|
|
|
|
| 2858 |
tmp.GetAtomWithIdx(idx).SetAtomicNum(0) # wildcard
|
| 2859 |
mutated = tmp.GetMol()
|
| 2860 |
mut_smiles = Chem.MolToSmiles(mutated)
|
| 2861 |
+
mut_psmiles = normalize_generated_psmiles_out(mut_smiles)
|
| 2862 |
except Exception:
|
| 2863 |
scores[idx] = 0.0
|
| 2864 |
continue
|
|
|
|
| 2870 |
else:
|
| 2871 |
scores[idx] = float(baseline) - float(mut_val)
|
| 2872 |
|
|
|
|
| 2873 |
max_abs = max((abs(v) for v in scores.values()), default=0.0)
|
| 2874 |
rel_thresh = (min_rel_importance * max_abs) if max_abs > 0 else 0.0
|
| 2875 |
thresh = max(float(min_abs_importance), float(rel_thresh))
|
|
|
|
| 2880 |
selected = [i for i, v in ranked if abs(v) >= thresh]
|
| 2881 |
selected = selected[:k_cap]
|
| 2882 |
|
|
|
|
| 2883 |
if not selected and ranked:
|
| 2884 |
selected = [ranked[0][0]]
|
| 2885 |
|
|
|
|
| 2886 |
atom_colors: Dict[int, tuple] = {}
|
| 2887 |
sel_scores = np.array([scores[i] for i in selected], dtype=float)
|
| 2888 |
if cm is not None and sel_scores.size > 0:
|
|
|
|
| 2930 |
except Exception as e:
|
| 2931 |
return {"error": f"prop_attribution rendering failed: {e}"}
|
| 2932 |
|
|
|
|
| 2933 |
def process_query(self, user_query: str, user_inputs: Dict[str, Any] = None) -> Dict[str, Any]:
|
| 2934 |
plan = self.analyze_query(user_query)
|
| 2935 |
results = self.execute_plan(plan, user_inputs)
|
|
|
|
| 2937 |
|
| 2938 |
|
| 2939 |
if __name__ == "__main__":
    # Smoke-test entry point: build the orchestrator with default paths and
    # report readiness on stdout.
    config = OrchestratorConfig(paths=PathsConfig())
    orchestrator = PolymerOrchestrator(config)
    print("PolymerOrchestrator ready (5M heads + 5M inverse-design + LLM planner + occlusion explainability).")
|