PolyFusionAgent / PolyAgent /orchestrator.py
kaurm43's picture
Update PolyAgent/orchestrator.py
2b6a0dd verified
"""
PolyAgent Orchestrator
===========================
This file provides a modular orchestrator that:
- extracts polymer multimodal data (graph/geometry/fingerprints/PSMILES)
- encodes CL embeddings using PolyFusion encoders
- predicts single properties using best downstream heads
- performs inverse design using a CL-conditioned SELFIES-TED generator
- retrieves literature via local RAG + web APIs
- visualizes polymer renderings and explainability maps
- composes a final response along with verbatim tool outputs
"""
import os
import re
import json
import pickle
import sys
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from urllib.parse import urlparse
from huggingface_hub import snapshot_download
import numpy as np
import torch
import torch.nn as nn
# HF Transformers (for SELFIES-TED decoder)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import BaseModelOutput
# Imports for web fetching
try:
import requests
from bs4 import BeautifulSoup
except Exception:
requests = None
BeautifulSoup = None
# Imports for visuals
try:
from rdkit import Chem
from rdkit.Chem import Draw
except Exception:
Chem = None
Draw = None
try:
from matplotlib import cm
except Exception:
cm = None
# joblib + sentencepiece for 5M generator artifacts
try:
import joblib
except Exception:
joblib = None
try:
import sentencepiece as spm
except Exception:
spm = None
# selfies (for SELFIES→SMILES/PSMILES conversion)
try:
import selfies as sf
except Exception:
sf = None
RDKit_AVAILABLE = Chem is not None
SELFIES_AVAILABLE = sf is not None
# =============================================================================
# PATHS / CONFIGURATION
# =============================================================================
class PathsConfig:
"""
Centralized paths for Spaces/local runs.
On Hugging Face Spaces:
- Downloads required artifacts from a HF Model repo (weights) into a local cache dir
- Exposes stable local filesystem paths used by the rest of orchestrator.py
"""
def __init__(self):
# 1) HF model repo
self.hf_repo_id = os.getenv("POLYFUSION_WEIGHTS_REPO", "kaurm43/polyfusionagent-weights")
self.hf_repo_type = os.getenv("POLYFUSION_WEIGHTS_REPO_TYPE", "model") # usually "model"
# 2) Where to store downloaded files
default_root = "/data/polyfusion_cache" if os.path.isdir("/data") else os.path.expanduser("~/.cache/polyfusion_cache")
self.local_weights_root = os.getenv("POLYFUSION_WEIGHTS_DIR", default_root)
# 3) Optional token (only needed if the weights repo is private)
self.hf_token = os.getenv("HF_TOKEN", None)
# 4) Download (cached) + get local folder path.
allow = [
"tokenizer_spm_5m/**",
"polyfusion_cl_5m/**",
"downstream_heads_5m/**",
"inverse_design_5m/**",
"MANIFEST.txt",
]
self._weights_dir = snapshot_download(
repo_id=self.hf_repo_id,
repo_type=self.hf_repo_type,
local_dir=self.local_weights_root,
local_dir_use_symlinks=False,
token=self.hf_token,
allow_patterns=allow,
)
# 5) Map to the necessary files
self.cl_weights_path = os.path.join(self._weights_dir, "polyfusion_cl_5m", "pytorch_model.bin")
# If your Space also includes a local Chroma DB folder in the Space repo,
# keep this as-is. Otherwise, you can also host Chroma DB as a dataset/model repo.
self.chroma_db_path = os.getenv("CHROMA_DB_PATH", "chroma_polymer_db_big")
self.spm_model_path = os.path.join(self._weights_dir, "tokenizer_spm_5m", "spm.model")
self.spm_vocab_path = os.path.join(self._weights_dir, "tokenizer_spm_5m", "spm.vocab")
self.downstream_bestweights_5m_dir = os.path.join(self._weights_dir, "downstream_heads_5m")
self.inverse_design_5m_dir = os.path.join(self._weights_dir, "inverse_design_5m")
# 6) Optional: sanity-check required files
self._assert_exists(self.cl_weights_path, "CL weights")
self._assert_exists(self.spm_model_path, "SentencePiece model")
self._assert_exists(self.spm_vocab_path, "SentencePiece vocab")
@staticmethod
def _assert_exists(p: str, label: str):
if not os.path.exists(p):
raise FileNotFoundError(f"{label} not found at: {p}")
# =============================================================================
# DOI NORMALIZATION / RESOLUTION HELPERS
# =============================================================================
_DOI_RE = re.compile(r"^10\.\d{4,9}/\S+$", re.IGNORECASE)
def normalize_doi(raw: str) -> Optional[str]:
if not isinstance(raw, str):
return None
s = raw.strip()
if not s:
return None
# remove common prefixes
s = re.sub(r"^(?:https?://(?:dx\.)?doi\.org/)", "", s, flags=re.IGNORECASE)
s = re.sub(r"^doi:\s*", "", s, flags=re.IGNORECASE)
# trim trailing punctuation often attached in text
s = s.rstrip(").,;]}")
return s if _DOI_RE.match(s) else None
def doi_to_url(doi: str) -> str:
# doi is assumed normalized
return f"https://doi.org/{doi}"
def doi_resolves(doi_url: str, timeout: float = 6.0) -> bool:
"""
Best-effort resolver check. Keeps pipeline robust against dead/unregistered DOIs.
If requests is unavailable, do not block.
"""
if requests is None:
return True
try:
r = requests.head(doi_url, allow_redirects=True, timeout=timeout)
if r.status_code == 405:
# Some resolvers disallow HEAD; fall back to a lightweight GET.
r = requests.get(doi_url, allow_redirects=True, timeout=timeout, stream=True)
return 200 <= r.status_code < 400
except Exception:
return False
# =============================================================================
# CITATION / DOMAIN TAGGING HELPERS
# =============================================================================
def _url_to_domain(url: str) -> Optional[str]:
if not isinstance(url, str) or not url.strip():
return None
u = url.strip()
if not (u.startswith("http://") or u.startswith("https://")):
return None
try:
netloc = urlparse(u).netloc.strip().lower()
if netloc.startswith("www."):
netloc = netloc[4:]
# Reduce to ROOT domain (nature.com, springer.com, etc.)
parts = [p for p in netloc.split(".") if p]
if len(parts) <= 2:
return netloc or None
second_level = {
"co.uk", "ac.uk", "gov.uk", "org.uk",
"co.jp", "ne.jp", "or.jp",
"com.au", "net.au", "org.au", "edu.au",
"co.in", "com.br", "com.cn",
}
last2 = ".".join(parts[-2:])
last3 = ".".join(parts[-3:])
if last2 in second_level and len(parts) >= 3:
return last3
if last3 in second_level and len(parts) >= 4:
return ".".join(parts[-4:])
return last2
except Exception:
return None
def _attach_source_domains(obj: Any) -> Any:
"""
Recursively add a short source_domain field where URLs are present.
This enables domain-style citations like "(nature.com)" (note: the composer
later enforces DOI-URL bracket citations for papers).
"""
if isinstance(obj, list):
return [_attach_source_domains(x) for x in obj]
if isinstance(obj, dict):
out: Dict[str, Any] = {}
for k, v in obj.items():
out[k] = _attach_source_domains(v)
for url_key in ("url", "landing_page", "landingPage", "doi_url", "pdf_url", "link", "href"):
v = out.get(url_key)
dom = _url_to_domain(v) if isinstance(v, str) else None
if dom:
out.setdefault("source_domain", dom)
break
return out
return obj
def _index_citable_sources(report: Dict[str, Any]) -> Dict[str, Any]:
"""
Add 'cite_tag' fields for citable web/RAG items using DOI-first URL tags.
Requirement:
- Paper citations must use the COMPLETE DOI URL (https://doi.org/...) as the bracket text.
- If DOI is not available, fall back to the best http(s) URL.
Never uses numbered citations like [1], [2].
"""
if not isinstance(report, dict):
return report
citation_index: Dict[str, Any] = {"sources": []}
def is_citable_item(d: Dict[str, Any]) -> bool:
if not isinstance(d, dict):
return False
for k in ("url", "landing_page", "landingPage", "doi_url", "pdf_url", "link", "href"):
if isinstance(d.get(k), str) and (d[k].startswith("http://") or d[k].startswith("https://")):
return True
if isinstance(d.get("doi"), str) and d["doi"].strip():
return True
return False
def get_best_url(d: Dict[str, Any]) -> Optional[str]:
# DOI-first
doi = normalize_doi(d.get("doi", ""))
if doi:
return doi_to_url(doi)
for k in ("url", "landing_page", "landingPage", "doi_url", "pdf_url", "link", "href"):
v = d.get(k)
if isinstance(v, str) and (v.startswith("http://") or v.startswith("https://")):
return v
return None
def walk_and_tag(node: Any) -> Any:
if isinstance(node, list):
return [walk_and_tag(x) for x in node]
if isinstance(node, dict):
out = {k: walk_and_tag(v) for k, v in node.items()}
if is_citable_item(out):
url = get_best_url(out)
if isinstance(url, str) and url.startswith(("http://", "https://")):
cur = out.get("cite_tag")
if not (isinstance(cur, str) and cur.strip().startswith(("http://", "https://"))):
out["cite_tag"] = url.strip()
url = get_best_url(out)
dom = out.get("source_domain") or (_url_to_domain(url) if url else None) or "source"
citation_index["sources"].append(
{
"tag": out.get("cite_tag") if isinstance(out.get("cite_tag"), str) else url,
"domain": dom,
"title": out.get("title") or out.get("name") or "Untitled",
"url": url,
"doi": out.get("doi"),
}
)
return out
return node
tagged = walk_and_tag(report)
if isinstance(tagged, dict):
tagged.setdefault("citation_index", citation_index)
return tagged
report["citation_index"] = citation_index
return report
# =============================================================================
# INLINE CITATION ENFORCERS (distributed, deduped, DOI-first)
# =============================================================================
_CITE_COUNT_PATTERNS = [
r"(?:at\s+least\s+)?(\d{1,3})\s*(?:citations|citation|papers|paper|sources|source|references|reference)\b",
r"\bcite\s+(\d{1,3})\s*(?:papers|paper|sources|source|references|reference|citations|citation)\b",
r"\b(\d{1,3})\s*(?:papers|paper|sources|source|references|reference|citations|citation)\s*(?:minimum|min)\b",
]
def _infer_required_citation_count(text: str, default_n: int = 10) -> int:
q = (text or "").lower()
for pat in _CITE_COUNT_PATTERNS:
m = re.search(pat, q, flags=re.IGNORECASE)
if m:
try:
n = int(m.group(1))
return max(1, min(n, 200))
except Exception:
pass
return max(1, int(default_n))
def _collect_citation_links_from_report(report: Dict[str, Any]) -> List[Tuple[str, str]]:
"""
Return unique (cite_text, url) pairs from report['citation_index']['sources'].
cite_text is strictly the DOI URL (preferred) or URL fallback.
"""
out: List[Tuple[str, str]] = []
seen: set = set()
if not isinstance(report, dict):
return out
ci = report.get("citation_index", {})
sources = ci.get("sources") if isinstance(ci, dict) else None
if not isinstance(sources, list):
return out
for s in sources:
if not isinstance(s, dict):
continue
url = s.get("url")
if not isinstance(url, str) or not url.startswith(("http://", "https://")):
continue
cite_text = s.get("tag") if isinstance(s.get("tag"), str) and s.get("tag").strip() else url
if not isinstance(cite_text, str) or not cite_text.strip():
cite_text = url
cite_text = cite_text.strip()
key = url.strip()
if key in seen:
continue
seen.add(key)
out.append((cite_text, url.strip()))
return out
def _ensure_distributed_inline_citations(md: str, report: Dict[str, Any], min_needed: int) -> str:
"""
If the model fails to include enough inline clickable paper citations, inject them
in a distributed way (one per eligible paragraph, outside code blocks).
Tool citations ([T]) are NOT modified.
"""
if not isinstance(md, str) or not md.strip():
return md
if not isinstance(report, dict):
return md
if min_needed <= 0:
return md
citations = _collect_citation_links_from_report(report)
if not citations:
return md
existing_urls = set(re.findall(r"\[[^\]]+\]\((https?://[^)]+)\)", md))
need = max(0, int(min_needed) - len(existing_urls))
if need <= 0:
return md
remaining: List[Tuple[str, str]] = [(d, u) for (d, u) in citations if u not in existing_urls]
if not remaining:
return md
parts = re.split(r"(```[\s\S]*?```)", md)
rem_i = 0
for pi, part in enumerate(parts):
if rem_i >= len(remaining) or need <= 0:
break
if part.startswith("```") and part.endswith("```"):
continue
segs = re.split(r"(\n\s*\n)", part)
for si in range(0, len(segs), 2):
if rem_i >= len(remaining) or need <= 0:
break
para = segs[si]
if not isinstance(para, str) or not para.strip():
continue
if para.lstrip().startswith("#"):
continue
if re.search(r"\[[^\]]+\]\((https?://[^)]+)\)", para):
continue
if not re.search(
r"\b(reported|shown|demonstrated|study|studies|literature|evidence|review|according)\b",
para,
flags=re.IGNORECASE,
):
continue
cite_text, url = remaining[rem_i]
segs[si] = para.rstrip() + f" [{cite_text}]({url})"
rem_i += 1
need -= 1
parts[pi] = "".join(segs)
if need > 0 and rem_i < len(remaining):
md2 = "".join(parts)
parts2 = re.split(r"(```[\s\S]*?```)", md2)
for pi, part in enumerate(parts2):
if rem_i >= len(remaining) or need <= 0:
break
if part.startswith("```") and part.endswith("```"):
continue
segs = re.split(r"(\n\s*\n)", part)
for si in range(0, len(segs), 2):
if rem_i >= len(remaining) or need <= 0:
break
para = segs[si]
if not isinstance(para, str) or not para.strip():
continue
if para.lstrip().startswith("#"):
continue
if re.search(r"\[[^\]]+\]\((https?://[^)]+)\)", para):
continue
cite_text, url = remaining[rem_i]
segs[si] = para.rstrip() + f" [{cite_text}]({url})"
rem_i += 1
need -= 1
parts2[pi] = "".join(segs)
return "".join(parts2)
return "".join(parts)
def _normalize_and_dedupe_literature_links(md: str, report: Dict[str, Any]) -> str:
"""
Enforce:
- Link text must be COMPLETE DOI URL (preferred) or URL fallback.
- Each DOI/URL appears at most once in the answer.
Only operates outside fenced code blocks.
"""
if not isinstance(md, str) or not md.strip():
return md
if not isinstance(report, dict):
return md
url_to_text: Dict[str, str] = {}
ci = report.get("citation_index", {})
sources = ci.get("sources") if isinstance(ci, dict) else None
if isinstance(sources, list):
for s in sources:
if not isinstance(s, dict):
continue
url = s.get("url")
if not isinstance(url, str) or not url.startswith(("http://", "https://")):
continue
tag = s.get("tag")
pref = tag.strip() if isinstance(tag, str) and tag.strip() else url.strip()
url_to_text[url.strip()] = pref
parts = re.split(r"(```[\s\S]*?```)", md)
seen_urls: set = set()
def _rewrite_and_dedupe(text: str) -> str:
def repl(m: re.Match) -> str:
url = m.group(2).strip()
if url in seen_urls:
return ""
seen_urls.add(url)
pref = url_to_text.get(url, url)
return f"[{pref}]({url})"
return re.sub(r"\[([^\]]+)\]\((https?://[^)]+)\)", repl, text)
for i, part in enumerate(parts):
if part.startswith("```") and part.endswith("```"):
continue
parts[i] = _rewrite_and_dedupe(part)
parts[i] = re.sub(r"[ \t]{2,}", " ", parts[i])
parts[i] = re.sub(r"\n{3,}", "\n\n", parts[i])
return "".join(parts)
def autolink_doi_urls(md: str) -> str:
"""
Wrap bare DOI URLs in Markdown links outside code blocks.
"""
if not md:
return md
parts = re.split(r"(```[\s\S]*?```)", md)
for i, part in enumerate(parts):
if part.startswith("```") and part.endswith("```"):
continue
parts[i] = re.sub(
r"(?<!\]\()(?P<u>https?://doi\.org/10\.\d{4,9}/[^\s\)\],;]+)",
lambda m: f"[{m.group('u')}]({m.group('u')})",
part,
flags=re.IGNORECASE,
)
return "".join(parts)
# =============================================================================
# TOOL TAGS + VERBATIM TOOL OUTPUT RENDERER
# =============================================================================
def _assign_tool_tags_to_report(report: Dict[str, Any]) -> Dict[str, Any]:
"""
Ensure each tool output has a [T] cite tag.
"""
if not isinstance(report, dict):
return report
tool_outputs = report.get("tool_outputs", {})
if not isinstance(tool_outputs, dict):
return report
preferred = [
"data_extraction",
"cl_encoding",
"property_prediction",
"polymer_generation",
"rag_retrieval",
"web_search",
"report_generation",
]
tool_tag_map: Dict[str, str] = {}
tag = "[T]"
for tool in preferred:
node = tool_outputs.get(tool)
if node is None:
continue
tool_tag_map[tool] = tag
if isinstance(node, dict) and not node.get("cite_tag"):
node["cite_tag"] = tag
for tool, node in tool_outputs.items():
if tool in tool_tag_map or node is None:
continue
tool_tag_map[tool] = tag
if isinstance(node, dict) and not node.get("cite_tag"):
node["cite_tag"] = tag
try:
summary = report.get("summary", {}) or {}
if isinstance(summary, dict):
key_to_tool = {
"data_extraction": "data_extraction",
"cl_encoding": "cl_encoding",
"property_prediction": "property_prediction",
"generation": "polymer_generation",
"polymer_generation": "polymer_generation",
"rag_retrieval": "rag_retrieval",
"web_search": "web_search",
"report_generation": "report_generation",
}
for k, tool in key_to_tool.items():
node = summary.get(k)
if isinstance(node, dict) and tool in tool_tag_map and not node.get("cite_tag"):
node["cite_tag"] = tool_tag_map[tool]
except Exception:
pass
report.setdefault("tool_tag_index", tool_tag_map)
return report
def _render_tool_outputs_verbatim_md(report: Dict[str, Any]) -> str:
"""
Render tool outputs as verbatim JSON blocks.
"""
if not isinstance(report, dict):
return ""
tool_outputs = report.get("tool_outputs", {}) or {}
if not isinstance(tool_outputs, dict):
return ""
preferred = [
"data_extraction",
"cl_encoding",
"property_prediction",
"polymer_generation",
"rag_retrieval",
"web_search",
"report_generation",
]
keys = [k for k in preferred if k in tool_outputs] + [k for k in tool_outputs.keys() if k not in preferred]
chunks: List[str] = []
for k in keys:
out = tool_outputs.get(k)
if out is None:
continue
tag = out.get("cite_tag") if isinstance(out, dict) else None
header = f"### {tag} {k}" if isinstance(tag, str) and tag else f"### {k}"
chunks.append(header)
try:
chunks.append("```json\n" + json.dumps(out, indent=2, ensure_ascii=False) + "\n```")
except Exception:
chunks.append("```text\n" + str(out) + "\n```")
return "\n\n".join(chunks)
# =============================================================================
# PICKLE / JOBLIB COMPATIBILITY SHIMS
# =============================================================================
class LatentPropertyModel:
"""
Compatibility shim for joblib/pickle artifacts saved with references like:
__main__.LatentPropertyModel
"""
def predict(self, X):
for attr in ("model", "gpr", "gpr_model", "estimator", "predictor", "_model", "_gpr"):
if hasattr(self, attr):
obj = getattr(self, attr)
if hasattr(obj, "predict"):
return obj.predict(X)
raise AttributeError(
"LatentPropertyModel shim could not find an underlying predictor. "
"Artifact expects a wrapped model attribute with a .predict method."
)
def _install_unpickle_shims() -> None:
"""
Ensure that any classes pickled under __main__ are available at load time.
"""
main_mod = sys.modules.get("__main__")
if main_mod is not None and not hasattr(main_mod, "LatentPropertyModel"):
setattr(main_mod, "LatentPropertyModel", LatentPropertyModel)
def _safe_joblib_load(path: str):
"""
joblib.load wrapper that patches __main__ symbols on common pickle failures and retries once.
"""
if joblib is None:
raise RuntimeError("joblib not installed but required to load *.joblib artifacts (pip install joblib).")
try:
return joblib.load(path)
except Exception as e:
msg = str(e)
if "Can't get attribute 'LatentPropertyModel' on <module '__main__'" in msg:
_install_unpickle_shims()
return joblib.load(path)
raise
# =============================================================================
# PROPERTY + GENERATOR REGISTRY
# =============================================================================
def build_property_registries(paths: PathsConfig):
"""
Build registry dicts for:
- downstream property heads (checkpoint + metadata)
- inverse-design generator directories
"""
downstream = paths.downstream_bestweights_5m_dir
invgen = paths.inverse_design_5m_dir
PROPERTY_HEAD_PATHS = {
"density": os.path.join(downstream, "density", "best_run_checkpoint.pt"),
"glass transition": os.path.join(downstream, "glass_transition", "best_run_checkpoint.pt"),
"melting": os.path.join(downstream, "melting", "best_run_checkpoint.pt"),
"specific volume": os.path.join(downstream, "specific_volume", "best_run_checkpoint.pt"),
"thermal decomposition": os.path.join(downstream, "thermal_decomposition", "best_run_checkpoint.pt"),
}
PROPERTY_HEAD_META = {
"density": os.path.join(downstream, "density", "best_run_metadata.json"),
"glass transition": os.path.join(downstream, "glass_transition", "best_run_metadata.json"),
"melting": os.path.join(downstream, "melting", "best_run_metadata.json"),
"specific volume": os.path.join(downstream, "specific_volume", "best_run_metadata.json"),
"thermal decomposition": os.path.join(downstream, "thermal_decomposition", "best_run_metadata.json"),
}
GENERATOR_DIRS = {
"density": os.path.join(invgen, "density"),
"glass transition": os.path.join(invgen, "glass_transition"),
"melting": os.path.join(invgen, "melting"),
"specific volume": os.path.join(invgen, "specific_volume"),
"thermal decomposition": os.path.join(invgen, "thermal_decomposition"),
}
return PROPERTY_HEAD_PATHS, PROPERTY_HEAD_META, GENERATOR_DIRS
# =============================================================================
# Property name canonicalization + inference helpers
# =============================================================================
def canonical_property_name(name: str) -> str:
"""
Map user/tool inputs to the canonical keys used in registries.
"""
if not isinstance(name, str):
return ""
s = name.strip().lower()
s = s.replace("_", " ").replace("-", " ")
s = re.sub(r"\s+", " ", s)
aliases = {
"tg": "glass transition",
"glass transition temperature": "glass transition",
"glass transition temp": "glass transition",
"glass transition (tg)": "glass transition",
"t g": "glass transition",
"td": "thermal decomposition",
"thermal decomp": "thermal decomposition",
"thermal decomposition temperature": "thermal decomposition",
"sv": "specific volume",
}
return aliases.get(s, s)
_NUM_RE = r"[-+]?\d+(?:\.\d+)?"
def infer_property_from_text(text: str) -> Optional[str]:
s = (text or "").lower()
m = re.search(r"\bproperty\b\s*[:=]\s*([a-zA-Z _-]+)", s)
if m:
cand = m.group(1).strip().lower()
if "glass" in cand or re.search(r"\btg\b", cand):
return "glass transition"
if "density" in cand or re.search(r"\brho\b", cand):
return "density"
if "melting" in cand or re.search(r"\btm\b", cand):
return "melting"
if "specific" in cand or re.search(r"\bsv\b", cand):
return "specific volume"
if "decomp" in cand or "decomposition" in cand or re.search(r"\btd\b", cand):
return "thermal decomposition"
if "thermal decomposition" in s or "decomposition temperature" in s or "decomposition" in s or re.search(r"\btd\b", s):
return "thermal decomposition"
if "specific volume" in s or re.search(r"\bsv\b", s):
return "specific volume"
if "glass transition" in s or "glass-transition" in s or re.search(r"\btg\b", s):
return "glass transition"
if "melting" in s or "melt temperature" in s or re.search(r"\btm\b", s):
return "melting"
if "density" in s or re.search(r"\brho\b", s):
return "density"
return None
def infer_target_value_from_text(text: str, prop: Optional[str]) -> Optional[float]:
sl = (text or "").lower()
m = re.search(rf"\b(target_value|target|tgt)\b\s*[:=]?\s*({_NUM_RE})", sl)
if m:
try:
return float(m.group(2))
except Exception:
pass
prop = canonical_property_name(prop or "") if prop else ""
patterns = []
if prop == "glass transition":
patterns = [rf"\b(tg|glass\s*transition)\b\s*[:=]?\s*({_NUM_RE})"]
elif prop == "density":
patterns = [rf"\b(density|rho)\b\s*[:=]?\s*({_NUM_RE})"]
elif prop == "melting":
patterns = [rf"\b(tm|melting)\b\s*[:=]?\s*({_NUM_RE})"]
elif prop == "specific volume":
patterns = [rf"\b(specific\s*volume|sv)\b\s*[:=]?\s*({_NUM_RE})"]
elif prop == "thermal decomposition":
patterns = [rf"\b(td|thermal\s*decomposition|decomposition)\b\s*[:=]?\s*({_NUM_RE})"]
for pat in patterns:
m = re.search(pat, sl)
if m:
try:
return float(m.group(m.lastindex))
except Exception:
pass
tokens = []
if prop == "glass transition":
tokens = ["tg", "glass transition"]
elif prop == "density":
tokens = ["density", "rho"]
elif prop == "melting":
tokens = ["tm", "melting"]
elif prop == "specific volume":
tokens = ["specific volume", "sv"]
elif prop == "thermal decomposition":
tokens = ["td", "thermal decomposition", "decomposition"]
for tok in tokens:
for mt in re.finditer(re.escape(tok), sl):
window = sl[mt.end():mt.end() + 80]
mn = re.search(rf"({_NUM_RE})", window)
if mn:
try:
return float(mn.group(1))
except Exception:
pass
return None
# =============================================================================
# Tokenizers
# =============================================================================
class SimpleCharTokenizer:
def __init__(self, vocab_chars: List[str], special_tokens=("<pad>", "<s>", "</s>", "<unk>")):
self.special_tokens = list(special_tokens)
chars = [c for c in vocab_chars if c not in self.special_tokens]
self.vocab = list(self.special_tokens) + chars
self.piece_to_id = {p: i for i, p in enumerate(self.vocab)}
self.id_to_piece = {i: p for i, p in enumerate(self.vocab)}
def encode(self, text: str, out_type=int):
return [self.piece_to_id.get(ch, self.piece_to_id.get("<unk>")) for ch in text]
def decode(self, ids: List[int]) -> str:
pieces = [self.id_to_piece.get(int(i), "") for i in ids]
return "".join([p for p in pieces if p not in self.special_tokens])
def PieceToId(self, piece: str) -> Optional[int]:
return self.piece_to_id.get(piece, None)
def IdToPiece(self, idx: int) -> str:
return self.id_to_piece.get(int(idx), "")
def get_piece_size(self) -> int:
return len(self.vocab)
class SentencePieceTokenizerWrapper:
"""
Minimal wrapper with:
- encode(text) -> List[int]
- decode(ids) -> str
- PieceToId(piece) / IdToPiece(id)
- get_piece_size()
- special_tokens and optional _blocked_ids
"""
def __init__(self, model_path: str):
if spm is None:
raise RuntimeError("sentencepiece not installed but required for spm_5M.model (pip install sentencepiece).")
if not os.path.exists(model_path):
raise FileNotFoundError(f"SentencePiece model not found: {model_path}")
self.model_path = model_path
self.sp = spm.SentencePieceProcessor()
ok = self.sp.Load(model_path)
if not ok:
raise RuntimeError(f"Failed to load SentencePiece model at: {model_path}")
self.special_tokens = []
for t in ("<pad>", "<s>", "</s>", "<unk>"):
if self.sp.PieceToId(t) >= 0:
self.special_tokens.append(t)
blocked = []
for t in ("<pad>", "<unk>"):
tid = self.PieceToId(t)
if tid is not None:
blocked.append(tid)
setattr(self, "_blocked_ids", blocked)
if self.PieceToId("*") is None:
raise RuntimeError("SentencePiece tokenizer loaded but '*' token not found – aborting for safe PSMILES generation.")
def encode(self, text: str, out_type=int):
return list(self.sp.EncodeAsIds(text))
def decode(self, ids: List[int]) -> str:
return self.sp.DecodeIds(list(map(int, ids)))
def PieceToId(self, piece: str) -> Optional[int]:
pid = self.sp.PieceToId(piece)
return None if pid < 0 else int(pid)
def IdToPiece(self, idx: int) -> str:
return self.sp.IdToPiece(int(idx))
def get_piece_size(self) -> int:
return int(self.sp.GetPieceSize())
def normalize_generated_psmiles_out(s: str) -> str:
if not isinstance(s, str):
return s
return re.sub(r"\[\*\]", "*", s)
def psmiles_to_rdkit_smiles(psmiles: str) -> str:
"""
RDKit typically expects wildcard as [*]. Convert '*' -> '[*]' (but keep already-bracketed wildcards).
"""
if not isinstance(psmiles, str):
return ""
s = psmiles
if "*" in s and "[*]" not in s:
s = re.sub(r"\*", "[*]", s)
return s
_AT_BRACKET_UI_RE = re.compile(r"\[(at)\]", flags=re.IGNORECASE)
def replace_at_with_star(psmiles: str) -> str:
if not isinstance(psmiles, str) or not psmiles:
return psmiles
return _AT_BRACKET_UI_RE.sub("[*]", psmiles)
# =============================================================================
# SELFIES utilities
# =============================================================================
_SELFIES_TOKEN_RE = re.compile(r"\[[^\[\]]+\]")
def _selfies_compact(selfies_str: str) -> str:
return str(selfies_str).replace(" ", "").strip()
def _ensure_two_at_endpoints(selfies_str: str) -> str:
"""
Simple endpoint regularization. For polymer-style SELFIES this would
normally enforce two special endpoints; here we just compact.
"""
return _selfies_compact(selfies_str)
def selfies_to_smiles(selfies_str: str) -> str:
"""
Decode SELFIES to a canonical SMILES using RDKit, if available.
"""
if not SELFIES_AVAILABLE:
return _selfies_compact(selfies_str)
try:
s = _selfies_compact(selfies_str)
smi = sf.decoder(s)
if not isinstance(smi, str) or not smi:
return s
if not RDKit_AVAILABLE:
return smi
mol = Chem.MolFromSmiles(smi)
if mol is None:
return smi
try:
Chem.SanitizeMol(mol, catchErrors=True)
except Exception:
return smi
return Chem.MolToSmiles(mol, canonical=True)
except Exception:
return _selfies_compact(selfies_str)
def pselfies_to_psmiles(selfies_str: str) -> str:
"""
For this orchestrator we treat pSELFIES→PSMILES as SELFIES→canonical SMILES.
"""
return selfies_to_smiles(selfies_str)
# =============================================================================
# SELFIES-TED decoder
# =============================================================================
HF_TOKEN = os.environ.get("HF_TOKEN", None)
SELFIES_TED_MODEL_NAME = os.environ.get("SELFIES_TED_MODEL_NAME", "ibm-research/materials.selfies-ted")
GEN_MAX_LEN = 256
GEN_MIN_LEN = 10
GEN_TOP_P = 0.92
GEN_TEMPERATURE = 1.0
GEN_REPETITION_PENALTY = 1.05
LATENT_NOISE_STD_GEN = 0.15
def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
import time
last_err = None
for t in range(max_tries):
try:
return load_fn()
except Exception as e:
last_err = e
sleep_s = base_sleep * (1.6 ** t) + np.random.rand()
print(f"[WARN] HF load attempt {t+1}/{max_tries} failed: {e}. Sleeping {sleep_s:.1f}s then retry.")
time.sleep(sleep_s)
raise RuntimeError(f"Failed to load model from HF. Last error: {last_err}")
def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
"""
Load tokenizer + seq2seq model for SELFIES-TED.
"""
def _load_tok():
return AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, use_fast=True)
def _load_model():
return AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN)
tok = _hf_load_with_retries(_load_tok, max_tries=5)
model = _hf_load_with_retries(_load_model, max_tries=5)
return tok, model
class CLConditionedSelfiesTEDGenerator(nn.Module):
"""
CL embedding (latent) -> fixed-length memory -> conditions SELFIES-TED seq2seq.
"""
def __init__(self, tok, seq2seq_model, cl_emb_dim: int = 600, mem_len: int = 4):
super().__init__()
self.tok = tok
self.model = seq2seq_model
self.mem_len = int(mem_len)
self.cl_emb_dim = int(cl_emb_dim)
d_model = int(getattr(self.model.config, "d_model", getattr(self.model.config, "hidden_size", 1024)))
self.cl_to_d = nn.Sequential(
nn.Linear(self.cl_emb_dim, d_model),
nn.Tanh(),
nn.Dropout(0.1),
nn.Linear(d_model, d_model),
)
self.mem_pos = nn.Embedding(self.mem_len, d_model)
def build_encoder_outputs(self, z: torch.Tensor) -> Tuple[BaseModelOutput, torch.Tensor]:
device = z.device
B = z.size(0)
d = self.cl_to_d(z) # (B, d_model)
d = d.unsqueeze(1).expand(B, self.mem_len, d.size(-1)).contiguous()
pos = torch.arange(self.mem_len, device=device).unsqueeze(0).expand(B, -1)
d = d + self.mem_pos(pos)
attn = torch.ones((B, self.mem_len), dtype=torch.long, device=device)
return BaseModelOutput(last_hidden_state=d), attn
def forward_train(self, z: torch.Tensor, labels: torch.Tensor) -> Dict[str, torch.Tensor]:
enc_out, attn = self.build_encoder_outputs(z)
out = self.model(
encoder_outputs=enc_out,
attention_mask=attn,
labels=labels,
)
loss = out.loss
return {"loss": loss, "ce": loss.detach()}
@torch.no_grad()
def generate(
self,
z: torch.Tensor,
num_return_sequences: int = 1,
max_len: int = GEN_MAX_LEN,
top_p: float = GEN_TOP_P,
temperature: float = GEN_TEMPERATURE,
repetition_penalty: float = GEN_REPETITION_PENALTY,
) -> List[str]:
self.eval()
z = z.to(next(self.parameters()).device)
enc_out, attn = self.build_encoder_outputs(z)
gen = self.model.generate(
encoder_outputs=enc_out,
attention_mask=attn,
do_sample=True,
top_p=float(top_p),
temperature=float(temperature),
repetition_penalty=float(repetition_penalty),
num_return_sequences=int(num_return_sequences),
max_length=int(max_len),
min_length=int(GEN_MIN_LEN),
pad_token_id=int(self.tok.pad_token_id) if self.tok.pad_token_id is not None else None,
eos_token_id=int(self.tok.eos_token_id) if self.tok.eos_token_id is not None else None,
)
outs = self.tok.batch_decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=True)
outs = [_ensure_two_at_endpoints(_selfies_compact(o)) for o in outs]
return outs
# =============================================================================
# Latent -> property helper
# =============================================================================
def _predict_latent_property(latent_model: Any, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
z_use = np.asarray(z, dtype=np.float32)
if z_use.ndim == 1:
z_use = z_use.reshape(1, -1)
pca = getattr(latent_model, "pca", None)
if pca is not None:
z_use = pca.transform(z_use.astype(np.float32))
gpr = getattr(latent_model, "gpr", None)
if gpr is not None and hasattr(gpr, "predict"):
y_s = gpr.predict(z_use)
elif hasattr(latent_model, "predict"):
y_s = latent_model.predict(z_use)
else:
raise RuntimeError("Latent property model has no usable predictor (expected .gpr or .predict).")
y_s = np.array(y_s, dtype=np.float32).reshape(-1)
y_scaler = getattr(latent_model, "y_scaler", None)
if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
y_u = y_scaler.inverse_transform(y_s.reshape(-1, 1)).reshape(-1)
else:
y_u = y_s.copy()
return y_s.astype(np.float32), y_u.astype(np.float32)
# =============================================================================
# Legacy models
# =============================================================================
class TransformerDecoderOnly(nn.Module):
def __init__(
self,
vocab_size: int,
hidden_size: int,
num_layers: int = 8,
nhead: int = 10,
ff_mult: int = 4,
dropout: float = 0.1,
tie_embeddings: Optional[nn.Embedding] = None
):
super().__init__()
self.hidden_size = hidden_size
self.token_emb = tie_embeddings if tie_embeddings is not None else nn.Embedding(vocab_size, hidden_size)
self.pos_emb = nn.Embedding(4096, hidden_size)
dec_layer = nn.TransformerDecoderLayer(
d_model=hidden_size,
nhead=nhead,
dim_feedforward=hidden_size * ff_mult,
dropout=dropout,
activation="gelu",
batch_first=True,
)
self.decoder = nn.TransformerDecoder(dec_layer, num_layers=num_layers)
self.ln_f = nn.LayerNorm(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
if tie_embeddings is not None:
self.lm_head.weight = tie_embeddings.weight
def _make_causal_mask(self, L: int, device: torch.device) -> torch.Tensor:
return torch.triu(torch.full((L, L), float("-inf"), device=device), diagonal=1)
def forward(
self,
decoder_input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor,
decoder_attention_mask: Optional[torch.Tensor] = None
):
B, Ld = decoder_input_ids.size()
device = decoder_input_ids.device
pos_ids = torch.arange(Ld, device=device).unsqueeze(0).expand(B, Ld)
x = self.token_emb(decoder_input_ids) + self.pos_emb(pos_ids)
tgt_mask = self._make_causal_mask(Ld, device)
tgt_key_padding_mask = (decoder_attention_mask == 0) if decoder_attention_mask is not None else None
y = self.decoder(
tgt=x,
memory=encoder_hidden_states,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=None
)
y = self.ln_f(y)
return self.lm_head(y)
class InverseDesignDecoder(nn.Module):
"""
Legacy decoder-only inverse design model (kept for backward compatibility).
The new generation path uses CLConditionedSelfiesTEDGenerator instead.
"""
def __init__(self, vocab_size: int, hidden_size: int = 600, latent_dim: int = 600,
num_memory_tokens: int = 8, decoder_layers: int = 8):
super().__init__()
self.hidden_size = hidden_size
self.latent_dim = latent_dim
self.num_memory_tokens = num_memory_tokens
self.memory_proj = nn.Sequential(
nn.Linear(latent_dim, hidden_size),
nn.GELU(),
nn.Linear(hidden_size, hidden_size * num_memory_tokens)
)
self.decoder = TransformerDecoderOnly(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_layers=decoder_layers,
nhead=10,
ff_mult=4,
dropout=0.1,
tie_embeddings=None
)
def encode_memory_from_latent(self, latent: torch.Tensor) -> torch.Tensor:
memory_flat = self.memory_proj(latent)
return memory_flat.view(latent.size(0), self.num_memory_tokens, self.hidden_size)
# =============================================================================
# Orchestrator config
# =============================================================================
class OrchestratorConfig:
def __init__(self, paths: Optional[PathsConfig] = None):
self.paths = paths or PathsConfig()
self.base_dir = "."
self.cl_weights_path = self.paths.cl_weights_path
self.chroma_db_path = self.paths.chroma_db_path
self.rag_embedding_model = "text-embedding-3-small"
self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
self.model = os.getenv("OPENAI_MODEL", "gpt-4.1")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.spm_model_path = self.paths.spm_model_path
self.spm_vocab_path = self.paths.spm_vocab_path
self.springer_api_key = os.getenv("SPRINGER_NATURE_API_KEY", "")
self.semantic_scholar_key = os.getenv("SEMANTIC_SCHOLAR_API_KEY", "")
self.available_tools = {
"data_extraction": True,
"rag_retrieval": True,
"cl_encoding": True,
"property_prediction": True,
"polymer_generation": True,
"web_search": True,
"report_generation": True, # required by UI
"mol_render": True,
"gen_grid": True,
"prop_attribution": True,
}
TOOL_DESCRIPTIONS = {
"data_extraction": {
"name": "Extract Polymer Multimodal Data",
"description": "Extracts graphs, 3D geometry, fingerprints, and PSMILES",
"input": "PSMILES string or CSV file path",
"output": "JSON with graph, geometry, fingerprints, and canonical PSMILES",
},
"rag_retrieval": {
"name": "RAG Knowledge Base Query",
"description": "Retrieves relevant literature from local polymer KB (Chroma)",
},
"cl_encoding": {
"name": "Contrastive Learning Encoder",
"description": "Encodes polymers using pretrained 4-encoder CL system",
},
"property_prediction": {
"name": "Property Prediction (5M best_run_checkpoint + normalization)",
"description": (
"Predicts polymer properties using CL embeddings + best_run_checkpoint.pt "
"and applies saved normalization to return values in original units. "
"Prefers embeddings from cl_encoding if present."
),
},
"polymer_generation": {
"name": "Inverse Design Generation (5M PolyBART-style)",
"description": (
"Generates polymer PSMILES from a target property using StandardScaler + GPR "
"+ decoder_best_fold*.pt + SELFIES-TED backbone (as in G2.py). "
"Requires target_value; optionally uses CL embeddings from cl_encoding or "
"seed_psmiles to bias the latent sampling."
),
},
"web_search": {
"name": "On-the-fly Literature Search (real & virtual libraries)",
"description": (
"CrossRef, OpenAlex, EuropePMC, arXiv, Semantic Scholar, Springer Nature (API key), Internet Archive"
),
},
"report_generation": {
"name": "Report Generation",
"description": (
"Synthesizes available tool outputs into a single structured report object "
"(summary + tool outputs) that can be rendered by the UI."
),
},
"mol_render": {
"name": "Molecule Rendering",
"description": "2D render of PSMILES with optional highlights (PNG)",
},
"gen_grid": {
"name": "Generation Grid",
"description": "Grid of generated polymers with optional score badges (PNG)",
},
"prop_attribution": {
"name": "Property Attribution",
"description": (
"Per-atom attribution heatmap for predictions using leave-one-atom-out occlusion "
"and top-K highlighting (PNG)."
),
},
}
# =============================================================================
# Orchestrator
# =============================================================================
class PolymerOrchestrator:
def __init__(self, config: OrchestratorConfig):
self.config = config
# Build registries from placeholders (no behavior change; just centralization)
self.PROPERTY_HEAD_PATHS, self.PROPERTY_HEAD_META, self.GENERATOR_DIRS = build_property_registries(self.config.paths)
self._openai_client = None
self._openai_unavailable_reason = None
self._data_extractor = None
self._rag_retriever = None
self._cl_encoder = None
self._psmiles_tokenizer = None
# cached: (head_module, y_scaler, meta, ckpt_path)
self._property_heads: Dict[str, Tuple[torch.nn.Module, Any, Dict[str, Any], str]] = {}
# cached: (decoder_model, latent_prop_model, scaler_y, selfies_tok, meta, paths)
self._property_generators: Dict[str, tuple] = {}
# cached SELFIES-TED backbones keyed by model name
self._selfies_ted_cache: Dict[str, Tuple[Any, Any]] = {}
self.system_prompt = self._build_system_prompt()
# -------------------------------------------------------------------------
# OpenAI client
# -------------------------------------------------------------------------
@property
def openai_client(self):
if self._openai_client is None:
try:
from openai import OpenAI
if not self.config.openai_api_key:
self._openai_unavailable_reason = "OPENAI_API_KEY missing or empty"
self._openai_client = None
else:
self._openai_client = OpenAI(api_key=self.config.openai_api_key)
except Exception as e:
self._openai_unavailable_reason = f"OpenAI client init failed: {e}"
self._openai_client = None
return self._openai_client
def _build_system_prompt(self) -> str:
tools_info = json.dumps(TOOL_DESCRIPTIONS, indent=2)
available = [k for k, v in self.config.available_tools.items() if v]
return (
"You are the tool-planning module for **PolyAgent**, a polymer-science agent.\n"
"Your job is to inspect the user's questions and decide which tools\n"
"to run in which order.\n\n"
"Critical tool dependencies:\n"
"- property_prediction should run AFTER cl_encoding when possible and should reuse cl_encoding.embedding.\n"
"- polymer_generation is inverse-design and REQUIRES target_value (property -> PSMILES).\n\n"
f"Available tools (JSON spec):\n{tools_info}\n\n"
f"Enabled: {', '.join(available)}"
)
# =============================================================================
# Planner: LLM tool-calling
# =============================================================================
def analyze_query(self, user_query: str) -> Dict[str, Any]:
schema_keys = ["analysis", "tools_required", "execution_plan"]
if self.openai_client is None:
return {
"analysis": user_query,
"tools_required": [],
"execution_plan": [],
"note": f"OpenAI unavailable ({self._openai_unavailable_reason or 'unknown'})."
}
sys_prompt = (
self.system_prompt
+ "\nYou must create a tool execution plan. Do not answer the science.\n"
+ "Return a plan with keys exactly: "
+ json.dumps(schema_keys)
)
plan_tool = {
"type": "function",
"function": {
"name": "make_plan",
"description": "Create a tool execution plan for PolyAgent.",
"parameters": {
"type": "object",
"properties": {
"analysis": {"type": "string"},
"tools_required": {"type": "array", "items": {"type": "string"}},
"execution_plan": {
"type": "array",
"items": {
"type": "object",
"properties": {
"step": {"type": "integer"},
"tool": {"type": "string"},
"action": {"type": "string"},
"input": {"type": "string"},
},
"required": ["step", "tool", "action"]
}
}
},
"required": ["analysis", "tools_required", "execution_plan"]
}
}
}
try:
response = self.openai_client.chat.completions.create(
model=self.config.model,
messages=[
{"role": "system", "content": sys_prompt},
{"role": "user", "content": user_query},
],
tools=[plan_tool],
tool_choice={"type": "function", "function": {"name": "make_plan"}},
temperature=0.2,
max_tokens=700,
)
msg = response.choices[0].message
tool_calls = getattr(msg, "tool_calls", None) or []
if tool_calls:
args = tool_calls[0].function.arguments
plan = json.loads(args)
for k in schema_keys:
if k not in plan:
raise ValueError(f"Missing key '{k}' in tool plan")
return plan
raise RuntimeError("Tool-calling plan not returned; falling back to JSON mode.")
except Exception:
try:
response = self.openai_client.chat.completions.create(
model=self.config.model,
messages=[
{"role": "system", "content": sys_prompt + "\nReturn ONLY a JSON object and nothing else."},
{"role": "user", "content": user_query},
],
temperature=0.2,
max_tokens=700,
response_format={"type": "json_object"},
)
plan = json.loads(response.choices[0].message.content)
for k in schema_keys:
if k not in plan:
raise ValueError(f"Missing key '{k}' in model response")
return plan
except Exception as e:
return {
"analysis": user_query,
"tools_required": [],
"execution_plan": [],
"note": f"OpenAI planning failed: {str(e)}"
}
def execute_plan(self, plan: Dict[str, Any], user_inputs: Dict[str, Any] = None) -> Dict[str, Any]:
results = {"plan": plan, "steps": [], "final_output": None, "errors": []}
intermediate_data = user_inputs or {}
for step in plan.get("execution_plan", []):
step_num = step.get("step", 0)
tool_name = step.get("tool", "")
action = step.get("action", "")
try:
if tool_name == "data_extraction":
output = self._run_data_extraction(step, intermediate_data)
elif tool_name == "rag_retrieval":
output = self._run_rag_retrieval(step, intermediate_data)
elif tool_name == "cl_encoding":
output = self._run_cl_encoding(step, intermediate_data)
elif tool_name == "property_prediction":
output = self._run_property_prediction(step, intermediate_data)
elif tool_name == "polymer_generation":
output = self._run_polymer_generation(step, intermediate_data)
elif tool_name == "web_search":
output = self._run_web_search(step, intermediate_data)
elif tool_name == "report_generation":
output = self._run_report_generation(step, intermediate_data)
elif tool_name == "mol_render":
output = self._run_mol_render(step, intermediate_data)
elif tool_name == "gen_grid":
output = self._run_gen_grid(step, intermediate_data)
elif tool_name == "prop_attribution":
output = self._run_prop_attribution(step, intermediate_data)
else:
output = {"error": f"Unknown tool: {tool_name}"}
results["steps"].append({"step": step_num, "tool": tool_name, "action": action, "output": output})
intermediate_data[f"step_{step_num}_output"] = output
intermediate_data[tool_name] = output
except Exception as e:
results["errors"].append(f"Error in step {step_num} ({tool_name}): {str(e)}")
if results["steps"]:
results["final_output"] = results["steps"][-1]["output"]
return results
# ----------------- Data extraction ----------------- #
def _run_data_extraction(self, step: Dict, data: Dict) -> Dict:
if self._data_extractor is None:
try:
from Data_Modalities import AdvancedPolymerMultimodalExtractor
except Exception as e:
return {"error": f"Data_Modalities import failed: {e}"}
self._data_extractor = AdvancedPolymerMultimodalExtractor(csv_file="")
psmiles = data.get("psmiles", data.get("smiles", "")) or data.get("seed_psmiles", "")
if not psmiles:
return {"error": "No PSMILES provided"}
canonical = self._data_extractor.validate_and_standardize_smiles(psmiles)
if not canonical:
return {"error": f"Invalid PSMILES: {psmiles}"}
return {
"canonical_psmiles": canonical,
"graph": self._data_extractor.generate_molecular_graph(canonical),
"geometry": self._data_extractor.optimize_3d_geometry(canonical),
"fingerprints": self._data_extractor.calculate_morgan_fingerprints(canonical),
}
# ----------------- RAG retrieval ----------------- #
def _run_rag_retrieval(self, step: Dict, data: Dict) -> Dict:
try:
from rag_pipeline import (
build_retriever_from_web,
build_retriever,
POLYMER_KEYWORDS,
DEFAULT_TMP_DOWNLOAD_DIR,
DEFAULT_MAILTO,
PolymerStyleOpenAIEmbeddings,
)
from langchain_community.vectorstores import Chroma
except Exception as e:
return {"error": f"Could not import polymer rag_pipeline: {e}"}
if self._rag_retriever is None:
try:
persist_dir = self.config.chroma_db_path
if os.path.isdir(persist_dir) and os.listdir(persist_dir):
embeddings = PolymerStyleOpenAIEmbeddings(
model=self.config.rag_embedding_model,
api_key=self.config.openai_api_key
)
vector_store = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
self._rag_retriever = vector_store.as_retriever(search_kwargs={"k": 6})
else:
papers_dir = DEFAULT_TMP_DOWNLOAD_DIR
pdfs_present = os.path.isdir(papers_dir) and any(f.lower().endswith(".pdf") for f in os.listdir(papers_dir))
if pdfs_present:
self._rag_retriever = build_retriever(
papers_path=papers_dir,
persist_dir=persist_dir,
k=6,
embedding_model=self.config.rag_embedding_model,
vector_backend="chroma",
)
else:
self._rag_retriever = build_retriever_from_web(
polymer_keywords=POLYMER_KEYWORDS,
persist_dir=persist_dir,
tmp_download_dir=papers_dir,
k=6,
embedding_model=self.config.rag_embedding_model,
vector_backend="chroma",
mailto=DEFAULT_MAILTO,
)
except Exception as e:
return {"error": f"Failed to initialize RAG retriever: {e}"}
query = data.get("query", data.get("question", step.get("input", ""))) or ""
if not query:
return {"error": "No query provided"}
try:
docs = self._rag_retriever.get_relevant_documents(query)
except Exception as e:
return {"error": f"RAG retrieval failed: {e}"}
results = []
for i, doc in enumerate(docs or [], 1):
meta = getattr(doc, "metadata", {}) or {}
page_content = getattr(doc, "page_content", "") or ""
results.append({
"rank": i,
"content": page_content[:800],
"title": meta.get("title", "Unknown"),
"year": meta.get("year", ""),
"source": meta.get("source", meta.get("source_path", "")),
"venue": meta.get("venue", meta.get("journal", "")),
"url": meta.get("url") or meta.get("link") or meta.get("href") or "",
"doi": meta.get("doi") or "",
})
return {"query": query, "results": results}
# ----------------- CL encoding ----------------- #
def _ensure_cl_encoder(self):
if self._cl_encoder is None:
try:
from PolyFusion.GINE import GineEncoder, GineBlock, MaskedGINE, match_edge_attr_to_index, safe_get
from PolyFusion.SchNet import NodeSchNetWrapper
from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder
from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer
from PolyFusion.CL import MultimodalContrastiveModel
except Exception:
raise RuntimeError("Modules not available in python path")
if self._psmiles_tokenizer is None:
self._psmiles_tokenizer = build_psmiles_tokenizer(
spm_path=self.config.spm_model_path,
max_len=128,
)
vocab_sz = len(self._psmiles_tokenizer)
pad_id = self._psmiles_tokenizer.pad_token_id if self._psmiles_tokenizer.pad_token_id is not None else 0
gine = GineEncoder().to(self.config.device)
schnet = NodeSchNetWrapper().to(self.config.device)
fp = FingerprintEncoder().to(self.config.device)
psm = PSMILESDebertaEncoder(
model_dir_or_name=None,
vocab_size=vocab_sz,
pad_token_id=pad_id,
).to(self.config.device)
model = MultimodalContrastiveModel(gine, schnet, fp, psm, emb_dim=600).to(self.config.device)
try:
state_dict = torch.load(self.config.cl_weights_path, map_location=self.config.device, weights_only=False)
model.load_state_dict(state_dict, strict=False)
except Exception:
pass
model.eval()
self._cl_encoder = model
def _prepare_batch_from_extraction(self, multimodal_data: Dict) -> Dict:
batch: Dict[str, Dict[str, torch.Tensor]] = {}
# graph
if "graph" in multimodal_data:
graph = multimodal_data["graph"]
node_features = graph.get("node_features", [])
if len(node_features) > 0:
atomic_nums, chirality, formal_charge = [], [], []
for nf in node_features:
atomic_nums.append(int(nf.get("atomic_num", nf.get("atomic_number", 6))))
chirality.append(float(nf.get("chirality", 0)))
formal_charge.append(float(nf.get("formal_charge", 0)))
z_tensor = torch.tensor(atomic_nums, dtype=torch.long, device=self.config.device)
chirality_tensor = torch.tensor(chirality, dtype=torch.float, device=self.config.device)
formal_charge_tensor = torch.tensor(formal_charge, dtype=torch.float, device=self.config.device)
edge_indices = graph.get("edge_indices", [])
if edge_indices:
ei = torch.tensor(edge_indices, dtype=torch.long, device=self.config.device)
if ei.dim() == 2 and ei.size(1) == 2:
edge_index = ei.t().contiguous()
elif ei.dim() == 2 and ei.size(0) == 2:
edge_index = ei.contiguous()
else:
edge_index = torch.tensor([[], []], dtype=torch.long, device=self.config.device)
else:
edge_index = torch.tensor([[], []], dtype=torch.long, device=self.config.device)
edge_features = graph.get("edge_features", [])
if edge_features:
edge_attr = torch.tensor(
[[ef.get("bond_type", 0), ef.get("stereo", 0), float(ef.get("is_conjugated", False))]
for ef in edge_features],
dtype=torch.float, device=self.config.device,
)
else:
edge_attr = torch.zeros((edge_index.size(1), 3), dtype=torch.float, device=self.config.device)
# reconcile sizes
num_ei = edge_index.size(1)
num_ea = edge_attr.size(0)
if num_ei != num_ea:
if num_ei == 0:
edge_attr = torch.zeros((0, 3), dtype=torch.float, device=self.config.device)
elif num_ea > num_ei:
edge_attr = edge_attr[:num_ei].contiguous()
else:
pad = torch.zeros((num_ei - num_ea, 3), dtype=torch.float, device=self.config.device)
edge_attr = torch.cat([edge_attr, pad], dim=0)
batch["gine"] = {
"z": z_tensor,
"chirality": chirality_tensor,
"formal_charge": formal_charge_tensor,
"edge_index": edge_index,
"edge_attr": edge_attr,
"batch": torch.zeros(z_tensor.size(0), dtype=torch.long, device=self.config.device),
}
# geometry
if "geometry" in multimodal_data:
geom = multimodal_data["geometry"]
best_conf = geom.get("best_conformer", {})
if best_conf:
atomic_numbers = best_conf.get("atomic_numbers", [])
coordinates = best_conf.get("coordinates", [])
if atomic_numbers and coordinates:
batch["schnet"] = {
"z": torch.tensor(atomic_numbers, dtype=torch.long, device=self.config.device),
"pos": torch.tensor(coordinates, dtype=torch.float, device=self.config.device),
"batch": torch.zeros(len(atomic_numbers), dtype=torch.long, device=self.config.device),
}
# fingerprints
if "fingerprints" in multimodal_data:
fp_dict = multimodal_data["fingerprints"]
morgan_bits = fp_dict.get("morgan_r3_bits", [])
if morgan_bits:
fp_vec = [1 if b else 0 for b in morgan_bits[:2048]]
if len(fp_vec) < 2048:
fp_vec += [0] * (2048 - len(fp_vec))
batch["fp"] = {
"input_ids": torch.tensor(fp_vec, dtype=torch.long, device=self.config.device).unsqueeze(0),
"attention_mask": torch.ones(1, 2048, dtype=torch.bool, device=self.config.device),
}
# psmiles encoder input
if self._psmiles_tokenizer is None:
try:
from PolyFusion.DeBERTav2 import build_psmiles_tokenizer
self._psmiles_tokenizer = build_psmiles_tokenizer(
spm_path=self.config.spm_model_path,
max_len=128,
)
except Exception:
self._psmiles_tokenizer = None
psmiles_str = multimodal_data.get("canonical_psmiles", "")
if psmiles_str and self._psmiles_tokenizer is not None:
enc = self._psmiles_tokenizer(psmiles_str, truncation=True, padding="max_length", max_length=128)
batch["psmiles"] = {
"input_ids": torch.tensor(enc["input_ids"], dtype=torch.long, device=self.config.device).unsqueeze(0),
"attention_mask": torch.tensor(enc["attention_mask"], dtype=torch.long, device=self.config.device).unsqueeze(0),
}
return batch
def _run_cl_encoding(self, step: Dict, data: Dict) -> Dict:
multimodal_data = data.get("data_extraction", {})
if not multimodal_data or "canonical_psmiles" not in multimodal_data:
return {"error": "No multimodal data found. Run data_extraction first."}
self._ensure_cl_encoder()
try:
batch_mods = self._prepare_batch_from_extraction(multimodal_data)
with torch.no_grad():
embeddings_dict = self._cl_encoder.encode(batch_mods)
required_modalities = ("gine", "schnet", "fp", "psmiles")
missing = [m for m in required_modalities if m not in embeddings_dict]
if missing:
return {"error": f"Missing CL embeddings for modalities: {', '.join(missing)}"}
all_embs = [embeddings_dict[k] for k in required_modalities]
final_embedding = torch.stack(all_embs, dim=0).mean(dim=0).squeeze(0).contiguous()
return {
"embedding": final_embedding.detach().cpu().tolist(),
"embedding_dim": int(final_embedding.shape[-1]),
"modalities_used": list(required_modalities),
"psmiles": multimodal_data["canonical_psmiles"],
}
except Exception as e:
return {"error": f"Failed to encode: {e}"}
# ----------------- Property heads (downstream) ----------------- #
def _load_property_head(self, property_name: str):
import torch.nn as nn
property_name = canonical_property_name(property_name)
prop_ckpt = self.PROPERTY_HEAD_PATHS.get(property_name)
prop_meta = self.PROPERTY_HEAD_META.get(property_name)
if prop_ckpt is None:
raise ValueError(f"No property head registered for: {property_name}")
if not os.path.exists(prop_ckpt):
raise FileNotFoundError(f"Property head checkpoint not found: {prop_ckpt}")
if property_name in self._property_heads:
return self._property_heads[property_name]
meta: Dict[str, Any] = {}
if prop_meta and os.path.exists(prop_meta):
try:
with open(prop_meta, "r") as fh:
meta = json.load(fh)
except Exception:
meta = {}
ckpt = torch.load(prop_ckpt, map_location=self.config.device, weights_only=False)
state_dict = None
for k in ("state_dict", "model_state_dict", "model_state", "head_state_dict", "regressor_state_dict"):
if isinstance(ckpt, dict) and k in ckpt and isinstance(ckpt[k], dict):
state_dict = ckpt[k]
break
if state_dict is None and isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
state_dict = ckpt
if state_dict is None:
raise RuntimeError(f"Could not find a usable state dict in {prop_ckpt}")
class RegressionHeadOnly(nn.Module):
def __init__(self, hidden_dim=600, dropout=0.1):
super().__init__()
self.head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, x):
return self.head(x).squeeze(-1)
head = RegressionHeadOnly(hidden_dim=600, dropout=float(meta.get("dropout", 0.1))).to(self.config.device)
normalized = {}
for k, v in state_dict.items():
nk = k
if nk.startswith("module."):
nk = nk[len("module."):]
if nk.startswith("model."):
nk = nk[len("model."):]
if nk.startswith("regressor."):
nk = nk.replace("regressor.", "head.", 1)
if nk.startswith("head."):
normalized[nk] = v
elif re.match(r"^\d+\.", nk):
normalized["head." + nk] = v
else:
normalized["head." + nk] = v
head.load_state_dict(normalized, strict=False)
head.eval()
y_scaler = None
if isinstance(ckpt, dict):
for sk in ("y_scaler", "scaler_y", "target_scaler", "y_normalizer"):
if sk in ckpt:
y_scaler = ckpt.get(sk)
break
if y_scaler is None and isinstance(meta, dict) and joblib is not None:
for path_key in ("y_scaler_path", "target_scaler_path", "scaler_path", "y_norm_path"):
spath = meta.get(path_key)
if spath and isinstance(spath, str) and os.path.exists(spath):
try:
y_scaler = joblib.load(spath)
break
except Exception:
y_scaler = None
self._property_heads[property_name] = (head, y_scaler, meta, prop_ckpt)
return self._property_heads[property_name]
def _run_property_prediction(self, step: Dict, data: Dict) -> Dict:
property_name = data.get("property", data.get("property_name", None))
if property_name is None:
return {"error": "Specify property name"}
property_name = canonical_property_name(property_name)
if property_name not in self.PROPERTY_HEAD_PATHS:
return {"error": f"Unsupported property: {property_name}"}
emb_from_cl = None
cl = data.get("cl_encoding", None)
if isinstance(cl, dict) and isinstance(cl.get("embedding"), list) and len(cl["embedding"]) == 600:
emb_from_cl = torch.tensor([cl["embedding"]], dtype=torch.float32, device=self.config.device)
multimodal = data.get("data_extraction", None)
psmiles = data.get("psmiles", data.get("smiles", None))
if emb_from_cl is None:
if psmiles and not multimodal:
multimodal = self._run_data_extraction({"step": -1}, {"psmiles": psmiles})
if "error" in multimodal:
return multimodal
data["data_extraction"] = multimodal
if not multimodal or "canonical_psmiles" not in multimodal:
return {"error": "No multimodal data; provide psmiles or data_extraction first."}
self._ensure_cl_encoder()
try:
batch_mods = self._prepare_batch_from_extraction(multimodal)
with torch.no_grad():
embs = self._cl_encoder.encode(batch_mods)
required_modalities = ("gine", "schnet", "fp", "psmiles")
missing = [m for m in required_modalities if m not in embs]
if missing:
return {"error": f"CL encoder did not return embeddings for modalities: {', '.join(missing)}"}
all_embs = [embs[k] for k in required_modalities]
emb_from_cl = torch.stack(all_embs, dim=0).mean(dim=0)
except Exception as e:
return {"error": f"Failed to compute CL embedding: {e}"}
try:
head, y_scaler, meta, ckpt_path = self._load_property_head(property_name)
with torch.no_grad():
pred_norm = head(emb_from_cl).squeeze(0).item()
pred_value = float(pred_norm)
if y_scaler is not None and hasattr(y_scaler, "inverse_transform"):
try:
inv = y_scaler.inverse_transform(np.array([[pred_norm]], dtype=float))
pred_value = float(inv[0][0])
except Exception:
pred_value = float(pred_norm)
else:
mean = (meta or {}).get("scaler_mean", None)
scale = (meta or {}).get("scaler_scale", None)
try:
if isinstance(mean, list) and isinstance(scale, list) and len(mean) == 1 and len(scale) == 1:
pred_value = float(pred_norm) * float(scale[0]) + float(mean[0])
except Exception:
pred_value = float(pred_norm)
out_psmiles = None
if isinstance(multimodal, dict):
out_psmiles = multimodal.get("canonical_psmiles")
if out_psmiles is None and isinstance(cl, dict):
out_psmiles = cl.get("psmiles")
if out_psmiles is None:
out_psmiles = psmiles
return {
"psmiles": out_psmiles,
"property": property_name,
"predictions": {property_name: pred_value},
"prediction_normalized": float(pred_norm),
"head_checkpoint_path": ckpt_path,
"metadata_path": self.PROPERTY_HEAD_META.get(property_name, ""),
"normalization_applied": bool(
(y_scaler is not None and hasattr(y_scaler, "inverse_transform")) or
((meta or {}).get("scaler_mean") is not None and (meta or {}).get("scaler_scale") is not None)
),
"used_cl_embedding": True,
}
except Exception as e:
return {"error": f"Property prediction failed: {e}"}
# ----------------- Inverse design generator (CL + SELFIES-TED) ----------------- #
def _get_selfies_ted_backend(self, model_name: str) -> Tuple[Any, Any]:
if not model_name:
model_name = SELFIES_TED_MODEL_NAME
if model_name in self._selfies_ted_cache:
return self._selfies_ted_cache[model_name]
tok, model = load_selfies_ted_and_tokenizer(model_name)
model.to(self.config.device)
self._selfies_ted_cache[model_name] = (tok, model)
return tok, model
def _load_property_generator(self, property_name: str):
property_name = canonical_property_name(property_name)
if property_name in self._property_generators:
return self._property_generators[property_name]
base_dir = self.GENERATOR_DIRS.get(property_name)
if base_dir is None:
raise ValueError(f"No generator registered for: {property_name}")
if not os.path.isdir(base_dir):
raise FileNotFoundError(f"Generator directory not found: {base_dir}")
meta_path = os.path.join(base_dir, "meta.json")
if not os.path.exists(meta_path):
raise FileNotFoundError(f"meta.json not found in {base_dir}")
if joblib is None:
raise RuntimeError("joblib not installed but required to load *.joblib artifacts (pip install joblib).")
with open(meta_path, "r") as fh:
meta = json.load(fh) if fh else {}
best_fold = None
for k in ("best_fold", "selected_fold", "fold", "bestFold", "best_fold_idx"):
if k in meta:
try:
best_fold = int(meta[k])
break
except Exception:
best_fold = None
if best_fold is None:
best_fold = 1
decoder_path = os.path.join(base_dir, f"decoder_best_fold{best_fold}.pt")
if not os.path.exists(decoder_path):
decs = sorted([p for p in os.listdir(base_dir) if p.startswith("decoder_best_fold") and p.endswith(".pt")])
if not decs:
raise FileNotFoundError(f"No decoder_best_fold*.pt found in {base_dir}")
decoder_path = os.path.join(base_dir, decs[0])
scaler_path = None
gpr_path = None
for fn in os.listdir(base_dir):
low = fn.lower()
if low.startswith("standardscaler_") and low.endswith(".joblib"):
scaler_path = os.path.join(base_dir, fn)
if low.startswith("gpr_psmiles_") and low.endswith(".joblib"):
gpr_path = os.path.join(base_dir, fn)
if not scaler_path or not os.path.exists(scaler_path):
raise FileNotFoundError(f"StandardScaler *.joblib not found in {base_dir}")
if not gpr_path or not os.path.exists(gpr_path):
raise FileNotFoundError(f"GPR *.joblib not found in {base_dir}")
_install_unpickle_shims()
scaler_y = _safe_joblib_load(scaler_path)
latent_prop_model = _safe_joblib_load(gpr_path)
selfies_ted_name = meta.get("selfies_ted_model", SELFIES_TED_MODEL_NAME)
tok, selfies_backbone = self._get_selfies_ted_backend(selfies_ted_name)
cl_emb_dim = int(meta.get("cl_emb_dim", 600))
mem_len = int(meta.get("mem_len", 4))
decoder_model = CLConditionedSelfiesTEDGenerator(
tok=tok,
seq2seq_model=selfies_backbone,
cl_emb_dim=cl_emb_dim,
mem_len=mem_len,
).to(self.config.device)
ckpt = torch.load(decoder_path, map_location=self.config.device, weights_only=False)
state_dict = None
if isinstance(ckpt, dict) and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
state_dict = ckpt
elif isinstance(ckpt, dict):
for k in ("state_dict", "model_state_dict", "decoder_state_dict"):
if k in ckpt and isinstance(ckpt[k], dict):
state_dict = ckpt[k]
break
if state_dict is None:
raise RuntimeError(f"Could not find a usable state dict in decoder checkpoint: {decoder_path}")
decoder_model.load_state_dict(state_dict, strict=False)
decoder_model.eval()
paths = {
"base_dir": base_dir,
"meta_json": meta_path,
"decoder_checkpoint": decoder_path,
"scaler_joblib": scaler_path,
"gpr_joblib": gpr_path,
"selfies_ted_model": selfies_ted_name,
}
self._property_generators[property_name] = (decoder_model, latent_prop_model, scaler_y, tok, meta, paths)
return self._property_generators[property_name]
@torch.no_grad()
def _sample_latents_for_target(
self,
latent_prop_model: Any,
target_value: float,
num_samples: int,
latent_dim: int,
tol_scaled: float,
y_scaler: Optional[Any] = None,
seed_latents: Optional[List[np.ndarray]] = None,
latent_noise_std: float = LATENT_NOISE_STD_GEN,
extra_factor: int = 8,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
def _l2_normalize_np(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
n = np.linalg.norm(x, axis=-1, keepdims=True)
return x / np.clip(n, eps, None)
if y_scaler is not None and hasattr(y_scaler, "transform"):
target_s = float(y_scaler.transform(np.array([[target_value]], dtype=np.float32))[0, 0])
else:
target_s = float(target_value)
n_candidates = max(num_samples * extra_factor, num_samples * 4, 64)
latents: List[np.ndarray] = []
if seed_latents:
seeds = [np.asarray(z, dtype=np.float32).reshape(-1) for z in seed_latents]
for z0 in seeds:
z0 = _l2_normalize_np(z0.reshape(1, -1)).reshape(-1)
latents.append(z0)
per_seed = max(1, n_candidates // max(1, len(seeds)) - 1)
for _ in range(per_seed):
noise = np.random.randn(latent_dim).astype(np.float32) * float(latent_noise_std)
z = z0 + noise
z = _l2_normalize_np(z.reshape(1, -1)).reshape(-1)
latents.append(z)
else:
for _ in range(n_candidates):
z = np.random.randn(latent_dim).astype(np.float32)
z = _l2_normalize_np(z.reshape(1, -1)).reshape(-1)
latents.append(z)
Z = np.stack(latents, axis=0).astype(np.float32)
y_s, y_u = _predict_latent_property(latent_prop_model, Z)
errors = np.abs(y_s - target_s)
idx_sorted = np.argsort(errors)
kept = [i for i in idx_sorted if errors[i] <= float(tol_scaled)]
if len(kept) < num_samples:
kept = list(idx_sorted[:num_samples])
kept = kept[:num_samples]
return Z[kept], y_s[kept], y_u[kept], target_s
@torch.no_grad()
def _run_polymer_generation(self, step: Dict, data: Dict) -> Dict:
property_name = data.get("property", data.get("property_name", None))
if property_name is None:
return {"error": "Specify property name for generation"}
property_name = canonical_property_name(property_name)
if property_name not in self.GENERATOR_DIRS:
return {"error": f"Unsupported property: {property_name}"}
if data.get("target_value", None) is not None:
target_value = data["target_value"]
elif data.get("target", None) is not None:
target_value = data["target"]
elif data.get("target_property_value", None) is not None:
target_value = data["target_property_value"]
else:
return {"error": "Generation requires target_value (inverse design: property -> PSMILES)."}
try:
target_value = float(target_value)
except Exception:
return {"error": f"target_value must be numeric, got: {target_value!r}"}
num_samples = int(data.get("num_samples", 4))
if num_samples < 1:
num_samples = 1
top_p = float(data.get("top_p", GEN_TOP_P))
temperature = float(data.get("temperature", GEN_TEMPERATURE))
rep_pen = float(data.get("repetition_penalty", GEN_REPETITION_PENALTY))
max_len = int(data.get("max_len", GEN_MAX_LEN))
latent_noise_std = float(data.get("latent_noise_std", LATENT_NOISE_STD_GEN))
extra_factor = int(data.get("extra_factor", 8))
tol_scaled_override = data.get("tol_scaled", None)
try:
decoder_model, latent_prop_model, scaler_y, selfies_tok, meta, paths = self._load_property_generator(property_name)
except Exception as e:
return {"error": f"Failed to load inverse-design generator bundle: {e}"}
latent_dim = int(getattr(decoder_model, "cl_emb_dim", 600))
y_scaler = getattr(latent_prop_model, "y_scaler", None)
if y_scaler is None:
y_scaler = scaler_y if scaler_y is not None else None
tol_scaled = float(tol_scaled_override) if tol_scaled_override is not None else float(meta.get("tol_scaled", 0.5))
seed_latents: List[np.ndarray] = []
cl_enc = data.get("cl_encoding", None)
if isinstance(cl_enc, dict) and isinstance(cl_enc.get("embedding"), list):
emb = np.asarray(cl_enc["embedding"], dtype=np.float32)
if emb.shape[0] == latent_dim:
seed_latents.append(emb)
seeds_str: List[str] = []
if isinstance(data.get("seed_psmiles_list"), list):
seeds_str.extend([str(x) for x in data["seed_psmiles_list"] if isinstance(x, str)])
if data.get("seed_psmiles"):
seeds_str.append(str(data["seed_psmiles"]))
if data.get("psmiles") and not seeds_str:
seeds_str.append(str(data["psmiles"]))
seeds_str = list(dict.fromkeys(seeds_str))
if seeds_str and not seed_latents:
self._ensure_cl_encoder()
for s in seeds_str:
ex = self._run_data_extraction({}, {"psmiles": s})
if isinstance(ex, dict) and "error" in ex:
continue
cl = self._run_cl_encoding({}, {"data_extraction": ex})
if isinstance(cl, dict) and isinstance(cl.get("embedding"), list):
z = np.asarray(cl["embedding"], dtype=np.float32)
if z.shape[0] == latent_dim:
seed_latents.append(z)
try:
Z_keep, y_s_keep, y_u_keep, target_s = self._sample_latents_for_target(
latent_prop_model=latent_prop_model,
target_value=target_value,
num_samples=num_samples,
latent_dim=latent_dim,
tol_scaled=tol_scaled,
y_scaler=y_scaler,
seed_latents=seed_latents if seed_latents else None,
latent_noise_std=latent_noise_std,
extra_factor=extra_factor,
)
except Exception as e:
return {"error": f"Failed to sample latents conditioned on property: {e}", "paths": paths}
at_bracket_re = re.compile(r"\[(at)\]", flags=re.IGNORECASE)
def _at_to_star_bracket(s: str) -> str:
if not isinstance(s, str) or not s:
return s
return at_bracket_re.sub("[*]", s)
def _is_rdkit_valid(psmiles: str) -> bool:
if Chem is None:
return True
try:
probe = psmiles_to_rdkit_smiles(psmiles)
m = Chem.MolFromSmiles(probe)
return m is not None
except Exception:
return False
requested_k = int(num_samples)
candidates: List[Tuple[int, float, str, str, float, float]] = []
candidates_per_latent = max(1, int(extra_factor))
max_gen_rounds = 4
Z_round, y_s_round, y_u_round = Z_keep, y_s_keep, y_u_keep
for _round in range(max_gen_rounds):
for i in range(Z_round.shape[0]):
z_vec = torch.tensor(Z_round[i], dtype=torch.float32, device=self.config.device).unsqueeze(0)
try:
outs = decoder_model.generate(
z=z_vec,
num_return_sequences=candidates_per_latent,
max_len=max_len,
top_p=top_p,
temperature=temperature,
repetition_penalty=rep_pen,
)
for selfies_str in (outs or []):
psm_raw = pselfies_to_psmiles(selfies_str)
if _is_rdkit_valid(psm_raw):
psm_out = _at_to_star_bracket(psm_raw)
candidates.append(
(
len(psm_out) if isinstance(psm_out, str) else 0,
abs(float(y_s_round[i]) - float(target_s)),
psm_out,
selfies_str,
float(y_s_round[i]),
float(y_u_round[i]),
)
)
except Exception:
continue
if len(candidates) >= requested_k:
break
try:
Z_round, y_s_round, y_u_round, target_s = self._sample_latents_for_target(
latent_prop_model=latent_prop_model,
target_value=target_value,
num_samples=requested_k,
latent_dim=latent_dim,
tol_scaled=tol_scaled,
y_scaler=y_scaler,
seed_latents=seed_latents if seed_latents else None,
latent_noise_std=latent_noise_std,
extra_factor=extra_factor,
)
except Exception:
break
candidates.sort(key=lambda t: (t[0], t[1]))
selected = candidates[:requested_k]
if selected and len(selected) < requested_k:
while len(selected) < requested_k:
selected.append(selected[0])
generated_psmiles: List[str] = [t[2] for t in selected]
selfies_raw: List[str] = [t[3] for t in selected]
decoded_scaled: List[float] = [t[4] for t in selected]
decoded_unscaled: List[float] = [t[5] for t in selected]
return {
"property": property_name,
"target_value": float(target_value),
"num_samples": int(len(generated_psmiles)),
"generated_psmiles": generated_psmiles,
"generated_selfies": selfies_raw,
"latent_property_predictions": {
"scaled": decoded_scaled,
"unscaled": decoded_unscaled,
"target_scaled": float(target_s),
"tol_scaled": float(tol_scaled),
},
"inverse_design_paths": paths,
"selfies_ted_model": meta.get("selfies_ted_model", SELFIES_TED_MODEL_NAME),
"latent_dim": int(latent_dim),
"used_seed_latents": bool(seed_latents),
"seed_psmiles_used": seeds_str,
"rdkit_validation": {
"enabled": bool(Chem is not None),
"note": "Only RDKit-valid generated candidates are returned when RDKit is available."
if Chem is not None else
"RDKit not available; validity filtering could not be applied.",
},
"sampler": {
"MAX_LENGTH": max_len,
"TOP_P": top_p,
"TEMPERATURE": temperature,
"REPETITION_PENALTY": rep_pen,
"LATENT_NOISE_STD": latent_noise_std,
"EXTRA_FACTOR": extra_factor,
},
}
# ----------------- Web tools ----------------- #
def _crossref_search(self, query: str, rows: int = 6) -> List[Dict[str, Any]]:
if requests is None:
return [{"error": "requests not installed"}]
url = "https://api.crossref.org/works"
params = {
"query.bibliographic": query,
"rows": rows,
"filter": "type:journal-article,from-pub-date:2015-01-01",
}
try:
r = requests.get(url, params=params, timeout=12)
r.raise_for_status()
items = r.json().get("message", {}).get("items", [])
out = []
for it in items:
cr_type = (it.get("type") or "").lower()
if cr_type and cr_type != "journal-article":
continue
title = " ".join(it.get("title", [])) if it.get("title") else ""
doi = normalize_doi(it.get("DOI", "")) or ""
publisher = (it.get("publisher") or "").lower()
if doi and doi.startswith("10.1163/"):
continue
if "brill" in publisher:
continue
pub_year = None
if it.get("published-print") and isinstance(it["published-print"].get("date-parts"), list):
pub_year = it["published-print"]["date-parts"][0][0]
elif it.get("created"):
pub_year = it["created"].get("date-parts", [[None]])[0][0]
doi_url = doi_to_url(doi) if doi else ""
if doi_url and not doi_resolves(doi_url):
doi = ""
doi_url = ""
landing = (it.get("URL") or "") if isinstance(it.get("URL"), str) else ""
out.append({
"title": title,
"doi": doi,
"url": doi_url or landing or "",
"year": pub_year,
"source": "CrossRef",
"type": cr_type,
"publisher": it.get("publisher", ""),
})
return out
except Exception as e:
return [{"error": f"CrossRef query failed: {e}"}]
def _openalex_search(self, query: str, rows: int = 6) -> List[Dict[str, Any]]:
if requests is None:
return [{"error": "requests not installed"}]
try:
url = "https://api.openalex.org/works"
params = {"search": query, "per-page": rows}
r = requests.get(url, params=params, timeout=12)
r.raise_for_status()
items = r.json().get("results", [])
out = []
for it in items:
oa_type = (it.get("type") or "").lower()
if oa_type and oa_type not in {"journal-article", "proceedings-article", "posted-content"}:
continue
doi = normalize_doi(it.get("doi", "")) or ""
if doi and doi.startswith("10.1163/"):
continue
pl = (it.get("primary_location") or {})
landing = (
pl.get("landing_page_url")
or ((pl.get("source") or {}).get("homepage_url"))
or ""
)
doi_url = doi_to_url(doi) if doi else ""
if doi_url and not doi_resolves(doi_url):
doi = ""
doi_url = ""
out.append({
"title": it.get("title", ""),
"doi": doi,
"url": landing or "",
"year": it.get("publication_year") or (it.get("publication_date", "")[:4]),
"venue": (it.get("host_venue") or {}).get("display_name", ""),
"type": oa_type,
"source": "OpenAlex",
})
return out
except Exception as e:
return [{"error": f"OpenAlex query failed: {e}"}]
def _epmc_search(self, query: str, rows: int = 6) -> List[Dict[str, Any]]:
if requests is None:
return [{"error": "requests not installed"}]
try:
base = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
params = {"query": query, "format": "json", "pageSize": rows}
r = requests.get(base, params=params, timeout=12)
r.raise_for_status()
hits = r.json().get("resultList", {}).get("result", [])
out = []
for it in hits:
out.append({
"title": it.get("title", ""),
"pmcid": it.get("pmcid", ""),
"year": it.get("pubYear", ""),
"abstract": it.get("abstractText", ""),
"source": "EuropePMC",
})
return out
except Exception as e:
return [{"error": f"Europe PMC query failed: {e}"}]
def _arxiv_search(self, query: str, rows: int = 6) -> List[Dict[str, Any]]:
if requests is None:
return [{"error": "requests not installed"}]
if BeautifulSoup is None:
return [{"error": "bs4 not installed for arXiv parse"}]
try:
url = "http://export.arxiv.org/api/query"
params = {"search_query": f"all:{query}", "start": 0, "max_results": rows}
r = requests.get(url, params=params, timeout=12, headers={"User-Agent": "PolyOrch/1.0"})
r.raise_for_status()
soup = BeautifulSoup(r.text, "xml")
out = []
for entry in soup.find_all("entry"):
title = (entry.title.text or "").strip()
year = (entry.published.text or "")[:4] if entry.published else ""
link = ""
link_tag = entry.find("link", {"type": "text/html"})
if link_tag and link_tag.get("href"):
link = link_tag["href"]
elif entry.id:
link = entry.id.text
out.append({"title": title, "url": link, "year": year, "source": "arXiv"})
return out
except Exception as e:
return [{"error": f"arXiv query failed: {e}"}]
def _semantic_scholar_search(self, query: str, rows: int = 6) -> List[Dict[str, Any]]:
if requests is None:
return [{"error": "requests not installed"}]
try:
url = "https://api.semanticscholar.org/graph/v1/paper/search"
params = {"query": query, "limit": rows, "fields": "title,year,externalIds,url,venue,abstract"}
headers = {}
if self.config.semantic_scholar_key:
headers["x-api-key"] = self.config.semantic_scholar_key
r = requests.get(url, params=params, timeout=12, headers=headers)
r.raise_for_status()
papers = r.json().get("data", [])
out = []
for p in papers:
doi = normalize_doi((p.get("externalIds") or {}).get("DOI", "")) or ""
if doi and doi.startswith("10.1163/"):
continue
doi_url = doi_to_url(doi) if doi else ""
if doi_url and not doi_resolves(doi_url):
doi = ""
out.append({
"title": p.get("title", ""),
"doi": doi,
"url": p.get("url", "") or "",
"year": p.get("year", ""),
"venue": p.get("venue", ""),
"abstract": p.get("abstract", ""),
"source": "SemanticScholar",
})
return out
except Exception as e:
return [{"error": f"Semantic Scholar query failed: {e}"}]
def _springer_nature_search(self, query: str, rows: int = 6) -> List[Dict[str, Any]]:
if requests is None:
return [{"error": "requests not installed"}]
if not self.config.springer_api_key:
return [{"warning": "SPRINGER_NATURE_API_KEY not set; skipping Springer Nature"}]
try:
url = "https://api.springernature.com/metadata/json"
params = {"q": query, "api_key": self.config.springer_api_key, "p": rows}
r = requests.get(url, params=params, timeout=12)
r.raise_for_status()
recs = r.json().get("records", [])
out = []
for rec in recs:
title = rec.get("title", "")
year = (rec.get("publicationDate", "") or "")[:4]
urlp = ""
if rec.get("url"):
urlp = rec["url"][0].get("value", "")
out.append({"title": title, "doi": rec.get("doi", ""), "url": urlp, "year": year, "source": "SpringerNature"})
return out
except Exception as e:
return [{"error": f"Springer Nature query failed: {e}"}]
def _internet_archive_search(self, query: str, rows: int = 6) -> List[Dict[str, Any]]:
if requests is None:
return [{"error": "requests not installed"}]
try:
url = "https://archive.org/advancedsearch.php"
params = {"q": query, "fl[]": "identifier,title,year,creator", "rows": rows, "output": "json"}
r = requests.get(url, params=params, timeout=12)
r.raise_for_status()
docs = r.json().get("response", {}).get("docs", [])
out = []
for d in docs:
ident = d.get("identifier", "")
out.append({
"title": d.get("title", ""),
"url": f"https://archive.org/details/{ident}" if ident else "",
"year": d.get("year", ""),
"source": "InternetArchive",
})
return out
except Exception as e:
return [{"error": f"Internet Archive query failed: {e}"}]
def _fetch_page(self, url: str, max_chars: int = 1200) -> Dict[str, Any]:
if requests is None or BeautifulSoup is None:
return {"error": "requests or bs4 not available"}
try:
r = requests.get(url, timeout=12, headers={"User-Agent": "PolyOrch/1.0"})
r.raise_for_status()
soup = BeautifulSoup(r.text, "html.parser")
title = (soup.title.string or "").strip() if soup.title else ""
paras = [p.get_text(separator=" ", strip=True) for p in soup.find_all("p")]
excerpt = ""
for p in paras:
if len(p) > 50:
excerpt = p
break
if not excerpt:
excerpt = soup.get_text(separator=" ", strip=True)[:max_chars]
return {"title": title, "excerpt": excerpt[:max_chars], "url": url}
except Exception as e:
return {"error": f"Fetch failed: {e}", "url": url}
def _run_web_search(self, step: Dict, data: Dict) -> Dict:
src = (data.get("source", data.get("src", "crossref")) or "").lower()
query = data.get("query", data.get("q", "")) or ""
rows = int(data.get("rows", 6))
if src in ("crossref", "openalex", "epmc", "arxiv", "semanticscholar", "springer", "internetarchive", "all") and not query:
return {"error": f"No query provided for {src} search"}
if src == "crossref":
return {"source": "crossref", "query": query, "results": self._crossref_search(query, rows)}
if src == "openalex":
return {"source": "openalex", "query": query, "results": self._openalex_search(query, rows)}
if src == "epmc":
return {"source": "epmc", "query": query, "results": self._epmc_search(query, rows)}
if src == "arxiv":
return {"source": "arxiv", "query": query, "results": self._arxiv_search(query, rows)}
if src == "semanticscholar":
return {"source": "semanticscholar", "query": query, "results": self._semantic_scholar_search(query, rows)}
if src == "springer":
return {"source": "springer", "query": query, "results": self._springer_nature_search(query, rows)}
if src == "internetarchive":
return {"source": "internetarchive", "query": query, "results": self._internet_archive_search(query, rows)}
if src == "fetch":
url = data.get("url", "")
if not url:
return {"error": "No URL provided for fetch"}
return {"source": "fetch", "url": url, "page": self._fetch_page(url)}
if src == "all":
aggregated = {
"crossref": self._crossref_search(query, rows),
"openalex": self._openalex_search(query, rows),
"epmc": self._epmc_search(query, rows),
"arxiv": self._arxiv_search(query, rows),
"semanticscholar": self._semantic_scholar_search(query, rows),
"springer": self._springer_nature_search(query, rows),
"internetarchive": self._internet_archive_search(query, rows),
}
return {"source": "all", "query": query, "results": aggregated}
return {"error": f"Unsupported web_search source: {src}"}
# =============================================================================
# REPORT GENERATION
# =============================================================================
def generate_report(self, data: Dict[str, Any]) -> Dict[str, Any]:
payload = dict(data or {})
summary: Dict[str, Any] = {}
prop = payload.get("property") or payload.get("property_name")
if prop:
payload["property"] = prop
if not payload.get("property"):
qtxt = payload.get("questions") or payload.get("question") or ""
inferred_prop = infer_property_from_text(qtxt)
if inferred_prop:
payload["property"] = inferred_prop
psmiles = payload.get("psmiles") or payload.get("seed_psmiles")
if psmiles:
payload["psmiles"] = psmiles
if payload.get("target_value", None) is None:
qtxt = payload.get("questions") or payload.get("question") or ""
inferred_tgt = infer_target_value_from_text(qtxt, payload.get("property"))
if inferred_tgt is not None:
payload["target_value"] = float(inferred_tgt)
if psmiles and "data_extraction" not in payload:
ex = self._run_data_extraction({"step": -1}, payload)
payload["data_extraction"] = ex
summary["data_extraction"] = ex
if "data_extraction" in payload and "cl_encoding" not in payload:
cl = self._run_cl_encoding({"step": -1}, payload)
payload["cl_encoding"] = cl
summary["cl_encoding"] = cl
if payload.get("property") and "property_prediction" not in payload:
pp = self._run_property_prediction({"step": -1}, payload)
payload["property_prediction"] = pp
summary["property_prediction"] = pp
do_gen = bool(payload.get("generate", False)) or (payload.get("target_value", None) is not None)
if do_gen and payload.get("property") and payload.get("target_value", None) is not None:
gen = self._run_polymer_generation({"step": -1}, payload)
payload["polymer_generation"] = gen
summary["generation"] = gen
q = payload.get("query") or payload.get("literature_query")
src = payload.get("source") or "all"
if q:
ws = self._run_web_search({"step": -1}, {"source": src, "query": q, "rows": int(payload.get("rows", 6))})
payload["web_search"] = ws
summary["web_search"] = ws
report = {
"summary": summary,
"tool_outputs": {
"data_extraction": payload.get("data_extraction"),
"cl_encoding": payload.get("cl_encoding"),
"property_prediction": payload.get("property_prediction"),
"polymer_generation": payload.get("polymer_generation"),
"web_search": payload.get("web_search"),
"rag_retrieval": payload.get("rag_retrieval"),
},
"questions": payload.get("questions") or payload.get("question") or "",
}
report = _attach_source_domains(report)
report = _index_citable_sources(report)
report = _assign_tool_tags_to_report(report)
return report
def _run_report_generation(self, step: Dict, data: Dict) -> Dict[str, Any]:
return self.generate_report(data)
# =============================================================================
# COMPOSER
# =============================================================================
def compose_gpt_style_answer(
self,
report: Dict[str, Any],
case_brief: str = "",
questions: str = "",
) -> Tuple[str, List[str]]:
imgs: List[str] = []
if isinstance(report, dict):
report = _attach_source_domains(report)
report = _index_citable_sources(report)
report = _assign_tool_tags_to_report(report)
if self.openai_client is None:
md_lines = []
if case_brief:
md_lines.append(case_brief.strip())
md_lines.append("")
if questions:
md_lines.append(questions.strip())
md_lines.append("")
md_lines.append("```json")
try:
md_lines.append(json.dumps(report, indent=2, ensure_ascii=False))
except Exception:
md_lines.append(str(report))
md_lines.append("```")
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
if verb:
md_lines.append("\n---\n\n## Tool outputs (verbatim)\n")
md_lines.append(verb)
return "\n".join(md_lines), imgs
try:
prompt = (
"You are PolyAgent - consider yourself as an expert in polymer science. Answer the user's questions using ONLY the provided report.\n"
"Do NOT follow a fixed template. Let the structure be driven by the user's questions.\n\n"
"CITATION RULES (STRICT):\n"
"- Tool facts: when you use any information from a tool output, cite it as [T] (exactly; no numbering).\n"
"- Literature/web facts: cite using the COMPLETE DOI URL (https://doi.org/...) in brackets as a Markdown hyperlink.\n"
" The bracket text MUST be the full DOI URL (or the best URL if DOI is unavailable), and the href MUST be that same URL.\n"
"- NEVER use numbered citations like [1], [2] for papers.\n"
"- Every literature/web/RAG citation MUST be an inline Markdown hyperlink placed immediately after the claim.\n"
"- You are FORBIDDEN from adding any 'References', 'Sources', 'Bibliography', or 'Works Cited' section.\n"
"- Distribute citations across the answer (do not cluster them in one place).\n"
"- NON-DUPLICATES: Do not repeat the same paper link. Each DOI/URL may appear at most once in the entire answer.\n"
"- Each major section should include at least 1 inline literature citation when relevant.\n"
"- Do NOT invent DOIs, URLs, titles, or sources.\n\n"
"OUTPUT RULES (STRICT):\n"
"- If a numeric value is not present in the report, write 'not available'.\n"
"- Preserve polymer endpoint tokens exactly as '[*]' in any pSMILES/SMILES shown.\n"
"- To prevent markdown mangling, put any pSMILES/SMILES inside code formatting.\n"
"- Do not rewrite or tweak any tool outputs; if you refer to them, reference them by tag (e.g., [T]).\n\n"
f"CASE BRIEF:\n{case_brief}\n\n"
f"QUESTIONS:\n{questions}\n\n"
f"REPORT (JSON):\n{json.dumps(report, ensure_ascii=False)}\n"
)
resp = self.openai_client.chat.completions.create(
model=self.config.model,
messages=[
{"role": "system", "content": "Return a single markdown answer."},
{"role": "user", "content": prompt},
],
temperature=0.3,
max_tokens=2200,
)
txt = resp.choices[0].message.content or ""
try:
min_cites = _infer_required_citation_count(questions or "", default_n=10)
txt = _ensure_distributed_inline_citations(txt, report, min_needed=min_cites)
except Exception:
pass
try:
txt = _normalize_and_dedupe_literature_links(txt, report)
except Exception:
pass
try:
txt = autolink_doi_urls(txt)
except Exception:
pass
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
if verb:
txt = txt.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
return txt, imgs
except Exception as e:
md = f"OpenAI compose failed: {e}\n\n```json\n{json.dumps(report, indent=2, ensure_ascii=False)}\n```"
verb = _render_tool_outputs_verbatim_md(report) if isinstance(report, dict) else ""
if verb:
md = md.rstrip() + "\n\n---\n\n## Tool outputs (verbatim)\n\n" + verb
return md, imgs
# =============================================================================
# VISUAL TOOLS
# =============================================================================
def _run_mol_render(self, step: Dict, data: Dict) -> Dict[str, Any]:
out_dir = Path("viz")
out_dir.mkdir(parents=True, exist_ok=True)
if Chem is None or Draw is None:
return {"error": "RDKit not installed"}
p = data.get("psmiles") or data.get("seed_psmiles")
if not p:
return {"error": "no psmiles"}
mol = Chem.MolFromSmiles(psmiles_to_rdkit_smiles(p))
if mol is None:
return {"error": "invalid psmiles"}
img = Draw.MolToImage(mol, size=(600, 400))
png = str(out_dir / "mol.png")
img.save(png)
return {"png_path": png, "legend": p}
def _run_gen_grid(self, step: Dict, data: Dict) -> Dict[str, Any]:
out_dir = Path("viz")
out_dir.mkdir(parents=True, exist_ok=True)
if Chem is None or Draw is None:
return {"error": "RDKit not installed"}
p_list = data.get("psmiles_list")
if p_list is None:
gen = data.get("polymer_generation") or {}
p_list = gen.get("generated_psmiles", [])
if not p_list:
return {"error": "no psmiles_list provided and no generated_psmiles found"}
mols = []
legends = []
for i, p in enumerate(p_list, 1):
m = Chem.MolFromSmiles(psmiles_to_rdkit_smiles(p)) if p else None
if m is None:
continue
mols.append(m)
legends.append(f"{i}")
if not mols:
return {"error": "no valid molecules to render"}
img = Draw.MolsToGridImage(mols, molsPerRow=min(4, len(mols)), subImgSize=(300, 220), legends=legends, useSVG=False)
png = str(out_dir / "gen_grid.png")
img.save(png)
return {"png_path": png, "n": len(mols)}
def _run_prop_attribution(self, step: Dict, data: Dict) -> Dict[str, Any]:
out_dir = Path("viz")
out_dir.mkdir(parents=True, exist_ok=True)
if Chem is None or Draw is None:
return {"error": "RDKit not installed"}
p = data.get("psmiles") or data.get("seed_psmiles")
prop = canonical_property_name(data.get("property") or data.get("property_name") or "glass transition")
top_k = int(data.get("top_k_atoms", data.get("top_k", 12)))
min_rel_importance = float(data.get("min_rel_importance", 0.25))
min_abs_importance = float(data.get("min_abs_importance", 0.0))
if prop not in self.PROPERTY_HEAD_PATHS:
return {"error": f"Unsupported property for attribution: {prop}"}
if not p:
return {"error": "no psmiles"}
mol = Chem.MolFromSmiles(psmiles_to_rdkit_smiles(p))
if mol is None:
return {"error": "invalid psmiles"}
num_atoms = mol.GetNumAtoms()
if num_atoms <= 0:
return {"error": "molecule has no atoms"}
base_res = self._run_property_prediction({}, {"psmiles": p, "property": prop})
if "error" in base_res or "predictions" not in base_res:
return {"error": f"Baseline prediction failed: {base_res.get('error', 'unknown error')}"}
baseline = base_res["predictions"].get(prop)
if not isinstance(baseline, (float, int)):
return {"error": "Baseline prediction not numeric"}
scores: Dict[int, float] = {}
for idx in range(num_atoms):
try:
tmp = Chem.RWMol(mol)
tmp.GetAtomWithIdx(idx).SetAtomicNum(0) # wildcard
mutated = tmp.GetMol()
mut_smiles = Chem.MolToSmiles(mutated)
mut_psmiles = normalize_generated_psmiles_out(mut_smiles)
except Exception:
scores[idx] = 0.0
continue
mut_res = self._run_property_prediction({}, {"psmiles": mut_psmiles, "property": prop})
mut_val = (mut_res.get("predictions") or {}).get(prop) if isinstance(mut_res, dict) else None
if not isinstance(mut_val, (float, int)):
scores[idx] = 0.0
else:
scores[idx] = float(baseline) - float(mut_val)
max_abs = max((abs(v) for v in scores.values()), default=0.0)
rel_thresh = (min_rel_importance * max_abs) if max_abs > 0 else 0.0
thresh = max(float(min_abs_importance), float(rel_thresh))
ranked = sorted(scores.items(), key=lambda kv: abs(kv[1]), reverse=True)
k_cap = max(1, min(top_k, num_atoms))
selected = [i for i, v in ranked if abs(v) >= thresh]
selected = selected[:k_cap]
if not selected and ranked:
selected = [ranked[0][0]]
atom_colors: Dict[int, tuple] = {}
sel_scores = np.array([scores[i] for i in selected], dtype=float)
if cm is not None and sel_scores.size > 0:
denom = (np.max(sel_scores) - np.min(sel_scores))
if denom == 0:
norm = np.full_like(sel_scores, 0.5)
else:
norm = (sel_scores - np.min(sel_scores)) / denom
cmap = cm.get_cmap("coolwarm")
for i, n in zip(selected, norm):
r, g, b, _ = cmap(float(n))
atom_colors[i] = (float(r), float(g), float(b))
else:
max_mag = max(abs(v) for v in sel_scores) if sel_scores.size else 1.0
for i in selected:
v = scores[i] / (max_mag or 1.0)
if v >= 0:
atom_colors[i] = (1.0, 1.0 - 0.7 * v, 1.0 - 0.7 * v)
else:
vv = abs(v)
atom_colors[i] = (1.0 - 0.7 * vv, 1.0 - 0.7 * vv, 1.0)
try:
img = Draw.MolToImage(
mol,
size=(700, 450),
highlightAtoms=selected,
highlightAtomColors=atom_colors,
)
png = str(out_dir / "prop_attribution.png")
img.save(png)
return {
"png_path": png,
"per_atom_scores": {int(i): float(v) for i, v in scores.items()},
"highlighted_atoms": selected,
"baseline_prediction": float(baseline),
"property": prop,
"method": "leave_one_atom_out_occlusion_thresholded_topk",
"top_k_cap": int(k_cap),
"selected_k": int(len(selected)),
"min_rel_importance": float(min_rel_importance),
"min_abs_importance": float(min_abs_importance),
"used_threshold": float(thresh),
}
except Exception as e:
return {"error": f"prop_attribution rendering failed: {e}"}
def process_query(self, user_query: str, user_inputs: Dict[str, Any] = None) -> Dict[str, Any]:
plan = self.analyze_query(user_query)
results = self.execute_plan(plan, user_inputs)
return results
if __name__ == "__main__":
cfg = OrchestratorConfig(paths=PathsConfig())
orch = PolymerOrchestrator(cfg)
print("PolymerOrchestrator ready.")