diff --git a/literature/__init__.py b/literature/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b8d72852ed45c2c1e60a6b43e732915fd817a11e --- /dev/null +++ b/literature/__init__.py @@ -0,0 +1,100 @@ +"""Literature mining package for project-based extraction workflows.""" + +from .schemas import ( + ContextualizedValue, + DataQuality, + ExperimentalConditions, + ExtractionResult, + LiteratureEvidenceRecord, + LiteratureQuerySpec, + LiteratureSupportSummary, + PaperMetadata, + PaperCardResult, + PaperSource, + PolymerDataPoint, + QueryMode, + ReviewStatus, +) +from .property_registry import ( + PROPERTY_CATALOG, + PLATFORM_PROPERTY_KEYS, + TEMPLATES, + TEMPLATE_LABELS, + build_extraction_prompt, + detect_property_keys, + normalize_property_key, + property_display_name, +) +from .quality import QualityAssessor, QualityReport +from .standardizer import StandardizationResult, UnitStandardizer, normalize_minus_signs +from .clarifier import ClarifierAgent, QueryAnalysis +from .evaluation import evaluate_predictions, load_json_records + +try: + from .config import LiteratureConfig, get_config +except Exception: # pragma: no cover - optional runtime dependency + LiteratureConfig = None # type: ignore + get_config = None # type: ignore + +try: + from .discovery import PaperDiscoveryAgent +except Exception: # pragma: no cover - optional runtime dependency + PaperDiscoveryAgent = None # type: ignore + +try: + from .retrieval import PDFRetriever, extract_text_from_pdf +except Exception: # pragma: no cover - optional runtime dependency + PDFRetriever = None # type: ignore + extract_text_from_pdf = None # type: ignore + +try: + from .extraction import ContextualizedExtractor, DataExtractor +except Exception: # pragma: no cover - optional runtime dependency + ContextualizedExtractor = None # type: ignore + DataExtractor = None # type: ignore + +try: + from .converters import to_experiment_result +except Exception: # pragma: no cover - optional runtime dependency + to_experiment_result = None # type: ignore + +__all__ = [ + "LiteratureConfig", + "get_config", + "PaperMetadata", + "PaperSource", + "PolymerDataPoint", + "ExtractionResult", + "DataQuality", + "ContextualizedValue", + "ExperimentalConditions", + "LiteratureQuerySpec", + "PaperCardResult", + "LiteratureEvidenceRecord", + "LiteratureSupportSummary", + "QueryMode", + "ReviewStatus", + "PaperDiscoveryAgent", + "PDFRetriever", + "extract_text_from_pdf", + "DataExtractor", + "ContextualizedExtractor", + "QualityAssessor", + "QualityReport", + "UnitStandardizer", + "normalize_minus_signs", + "StandardizationResult", + "ClarifierAgent", + "QueryAnalysis", + "evaluate_predictions", + "load_json_records", + "to_experiment_result", + "PROPERTY_CATALOG", + "PLATFORM_PROPERTY_KEYS", + "TEMPLATES", + "TEMPLATE_LABELS", + "build_extraction_prompt", + "detect_property_keys", + "normalize_property_key", + "property_display_name", +]
diff --git a/literature/clarifier.py b/literature/clarifier.py new file mode 100644 index 0000000000000000000000000000000000000000..c21c2d42de9edac30a3a6b6d457929fc11fecb8a --- /dev/null +++ b/literature/clarifier.py @@ -0,0 +1,89 @@ +from __future__ import
annotations + +from dataclasses import dataclass +from typing import Dict, List + +from .property_registry import detect_property_keys, property_display_name + + +POLYMER_KEYWORDS = { + "polymer", + "polyimide", + "peek", + "polyethylene", + "pedot", + "pedot:pss", + "p3ht", + "smiles", +} + +CONDITION_KEYWORDS = { + "anneal", + "annealing", + "solvent", + "dopant", + "doping", + "spin coat", + "temperature", + "thickness", + "pressure", + "humidity", + "method", +} + + +@dataclass +class QueryAnalysis: + original_query: str + detected_polymers: List[str] + detected_properties: List[str] + detected_conditions: List[str] + suggestions: List[str] + clarification_required: bool + status: str + + def to_payload(self) -> Dict[str, object]: + return { + "original_query": self.original_query, + "detected_polymers": self.detected_polymers, + "detected_properties": self.detected_properties, + "detected_conditions": self.detected_conditions, + "suggestions": self.suggestions, + "clarification_required": self.clarification_required, + "status": self.status, + } + + +class ClarifierAgent: + """ + Lightweight clarifier for production search flows. + It nudges users toward material + property + condition context without + blocking valid free-form task queries. + """ + + def analyze(self, query: str) -> QueryAnalysis: + q = (query or "").lower() + polymers = [keyword for keyword in POLYMER_KEYWORDS if keyword in q] + properties = detect_property_keys(query or "") + conditions = [keyword for keyword in CONDITION_KEYWORDS if keyword in q] + + suggestions: List[str] = [] + if not polymers: + suggestions.append("Add a target polymer or material name.") + if not properties: + suggestions.append("Specify a key property focus, e.g. " + property_display_name("tg") + ".") + if not conditions: + suggestions.append("Add one processing or measurement condition if available.") + + clarification_required = (not polymers) and (not properties) + status = "pending_clarification" if clarification_required else "ready" + + return QueryAnalysis( + original_query=query, + detected_polymers=polymers, + detected_properties=properties, + detected_conditions=conditions, + suggestions=suggestions, + clarification_required=clarification_required, + status=status, + ) diff --git a/literature/config.py b/literature/config.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed2810a116bc6d97cbc28c7773b9fb0ff7ab52e --- /dev/null +++ b/literature/config.py @@ -0,0 +1,71 @@ +""" +Configuration management for Literature Discovery module. +Uses pydantic-settings for environment variable loading. 
+""" +from typing import Optional, List +from pydantic import Field +from pydantic_settings import BaseSettings +from functools import lru_cache + + +class LiteratureConfig(BaseSettings): + """Literature mining configuration.""" + + # API Keys + pubmed_email: str = Field(default="scholar@university.edu", alias="PUBMED_EMAIL") + pubmed_api_key: Optional[str] = Field(default=None, alias="PUBMED_API_KEY") + semantic_scholar_api_key: Optional[str] = Field(default=None, alias="SEMANTIC_SCHOLAR_API_KEY") + gemini_api_key: Optional[str] = Field(default=None, alias="GEMINI_API_KEY") + openai_api_key: Optional[str] = Field(default=None, alias="MY_OPEN_WEBUI_API_KEY") + openai_base_url: Optional[str] = Field(default=None, alias="OPENAI_BASE_URL") + pageindex_api_key: Optional[str] = Field(default=None, alias="PAGEINDEX_API_KEY") + + # LLM Configuration + llm_model: str = Field(default="gemini/gemini-2.0-flash", alias="LLM_MODEL") + embedding_model: str = Field(default="gemini/text-embedding-004") + llm_temperature: float = Field(default=0.1, ge=0.0, le=1.0) + llm_max_tokens: int = Field(default=4096) + + # Search Configuration + default_search_limit: int = Field(default=20) + pubmed_enabled: bool = Field(default=True) + arxiv_enabled: bool = Field(default=True) + semantic_scholar_enabled: bool = Field(default=True) # Now enabled + + # Rate Limiting (Semantic Scholar: 1 req/sec) + semantic_scholar_delay_s: float = Field(default=1.5) # Slightly over 1s for safety + pubmed_delay_s: float = Field(default=0.5) + + # Storage + pdf_storage_dir: str = Field(default="data/literature/raw_pdfs") + database_path: str = Field(default="data/literature/papers.db") + + # Processing + max_concurrent_downloads: int = Field(default=3) + extraction_timeout_s: int = Field(default=120) + + # PDF Download Headers (for avoiding 403) + user_agent: str = Field( + default="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" + ) + + # Target Polymers (for focused search) + target_polymers: List[str] = Field( + default=["PEDOT:PSS", "P3HT", "PBTTT", "P(NDI2OD-T2)", "PDPP-4T"] + ) + + # Extraction strategy: "paperqa" or "simple" + extraction_strategy: str = Field(default="simple") + literature_model_options: str = Field(default="[]", alias="LITERATURE_MODEL_OPTIONS") + + model_config = { + "env_file": ".env", + "env_file_encoding": "utf-8", + "extra": "ignore", + } + + +@lru_cache() +def get_config() -> LiteratureConfig: + """Get configuration singleton.""" + return LiteratureConfig() diff --git a/literature/converters.py b/literature/converters.py new file mode 100644 index 0000000000000000000000000000000000000000..360416dad715b226a6cd794660b4256605d2e2cc --- /dev/null +++ b/literature/converters.py @@ -0,0 +1,56 @@ +""" +Data model converters. + +This module is now schema-optional: +- If legacy `src.utils.schema` exists, returns (Experiment, Result) objects. +- Otherwise returns two plain dict payloads for compatibility. 
+""" +import time +from typing import Any, Dict, Tuple + +from .schemas import PolymerDataPoint + +try: + from src.utils.schema import Experiment, Result # type: ignore + HAS_LEGACY_SCHEMA = True +except Exception: + Experiment = None # type: ignore + Result = None # type: ignore + HAS_LEGACY_SCHEMA = False + + +def to_experiment_result(dp: PolymerDataPoint) -> Tuple[Any, Any]: + exp_id = f"lit_{dp.source_paper_id}_{int(time.time() * 1000)}" + exp_payload: Dict[str, Any] = { + "id": exp_id, + "polymer_id": dp.polymer_name, + "concentration_mg_ml": dp.concentration_mg_ml or 0.0, + "spin_speed_rpm": dp.spin_speed_rpm or 0, + "annealing_temp_c": dp.annealing_temp_c or 0.0, + "annealing_time_min": dp.annealing_time_min or 0.0, + "status": "completed", + "metadata": { + "dopant": dp.dopant, + "dopant_ratio": dp.dopant_ratio, + "solvent": dp.solvent, + "source_paper_id": dp.source_paper_id, + "source_table": dp.source_table_or_figure, + "quality_tier": dp.quality_tier.value, + "extraction_confidence": dp.extraction_confidence, + "film_thickness_nm": dp.film_thickness_nm, + "seebeck_coefficient_uv_k": dp.seebeck_coefficient_uv_k, + "power_factor_uw_m_k2": dp.power_factor_uw_m_k2, + }, + } + res_payload: Dict[str, Any] = { + "experiment_id": exp_id, + "ec_s_cm": dp.electrical_conductivity_s_cm or 0.0, + "tc_w_mk": dp.thermal_conductivity_w_mk, + "xrd_crystallinity": dp.xrd_crystallinity_percent, + "xrd_pi_stacking_angstrom": dp.xrd_pi_stacking_angstrom, + "source": "literature", + } + + if HAS_LEGACY_SCHEMA: + return Experiment(**exp_payload), Result(**res_payload) # type: ignore + return exp_payload, res_payload diff --git a/literature/discovery.py b/literature/discovery.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc7bf4d040e34036b7bcd943a3aa3d4a20eb635 --- /dev/null +++ b/literature/discovery.py @@ -0,0 +1,380 @@ +""" +Multi-source paper discovery module. +Implements PubMed, ArXiv, and Semantic Scholar search. +Uses synchronous code for MVP simplicity. +""" +import logging +import time +from typing import List, Optional + +import arxiv +from Bio import Entrez + +from .schemas import PaperMetadata, PaperSource +from .config import get_config + +logger = logging.getLogger(__name__) +_SEMANTIC_SCHOLAR_IMPORT_MISSING_LOGGED = False + + +class ArxivSearcher: + """ArXiv paper searcher.""" + + def __init__(self) -> None: + self.client = arxiv.Client() + + def search(self, query: str, limit: int = 10) -> List[PaperMetadata]: + """ + Search ArXiv for papers. 
+ + Args: + query: Search query string + limit: Maximum number of results + + Returns: + List of PaperMetadata objects + """ + logger.info(f"Searching ArXiv: '{query}' (limit={limit})") + + search = arxiv.Search( + query=query, + max_results=limit, + sort_by=arxiv.SortCriterion.Relevance + ) + + papers: List[PaperMetadata] = [] + try: + for result in self.client.results(search): + # Extract arxiv ID without version + arxiv_id = result.entry_id.split('/')[-1].split('v')[0] + + paper = PaperMetadata( + id=f"arxiv_{arxiv_id}", + title=result.title, + authors=[a.name for a in result.authors], + year=result.published.year if result.published else None, + doi=result.doi, + abstract=result.summary, + venue="arXiv", + citation_count=None, + is_open_access=True, + source=PaperSource.ARXIV, + url=result.entry_id, + landing_url=result.entry_id, + pdf_url=result.pdf_url, + ) + papers.append(paper) + except Exception as e: + logger.error(f"ArXiv search failed: {e}") + + logger.info(f"ArXiv returned {len(papers)} papers") + return papers + + +class PubMedSearcher: + """PubMed paper searcher using Biopython Entrez.""" + + def __init__(self) -> None: + config = get_config() + Entrez.email = config.pubmed_email + if config.pubmed_api_key: + Entrez.api_key = config.pubmed_api_key + self.delay = config.pubmed_delay_s + + def search(self, query: str, limit: int = 10) -> List[PaperMetadata]: + """ + Search PubMed for papers. + + Args: + query: Search query string + limit: Maximum number of results + + Returns: + List of PaperMetadata objects + """ + logger.info(f"Searching PubMed: '{query}' (limit={limit})") + + try: + # Step 1: Search for IDs + handle = Entrez.esearch(db="pubmed", term=query, retmax=limit) + record = Entrez.read(handle) + handle.close() + + id_list = record.get("IdList", []) + if not id_list: + logger.info("PubMed returned 0 papers") + return [] + + time.sleep(self.delay) + + # Step 2: Fetch details in XML format + handle = Entrez.efetch( + db="pubmed", + id=id_list, + rettype="xml", + retmode="xml" + ) + records = Entrez.read(handle) + handle.close() + + papers: List[PaperMetadata] = [] + for article in records.get("PubmedArticle", []): + try: + paper = self._parse_pubmed_article(article) + if paper: + papers.append(paper) + except Exception as e: + logger.warning(f"Failed to parse PubMed article: {e}") + + logger.info(f"PubMed returned {len(papers)} papers") + return papers + + except Exception as e: + logger.error(f"PubMed search failed: {e}") + return [] + + def _parse_pubmed_article(self, article: dict) -> Optional[PaperMetadata]: + """Parse a single PubMed article into PaperMetadata.""" + medline = article.get("MedlineCitation", {}) + article_data = medline.get("Article", {}) + + # Extract PMID + pmid = str(medline.get("PMID", "")) + if not pmid: + return None + + # Extract title + title = article_data.get("ArticleTitle", "Unknown Title") + if isinstance(title, list): + title = " ".join(str(t) for t in title) + + # Extract authors + authors: List[str] = [] + author_list = article_data.get("AuthorList", []) + for author in author_list: + if isinstance(author, dict): + last_name = author.get("LastName", "") + fore_name = author.get("ForeName", "") + if last_name: + authors.append(f"{fore_name} {last_name}".strip()) + + # Extract year + year = None + pub_date = article_data.get("Journal", {}).get("JournalIssue", {}).get("PubDate", {}) + if "Year" in pub_date: + try: + year = int(pub_date["Year"]) + except (ValueError, TypeError): + pass + + # Extract abstract + abstract = "" + 
abstract_text = article_data.get("Abstract", {}).get("AbstractText", []) + if isinstance(abstract_text, list): + abstract = " ".join(str(t) for t in abstract_text) + elif isinstance(abstract_text, str): + abstract = abstract_text + + # Extract DOI + doi = None + id_list = article_data.get("ELocationID", []) + for eid in id_list: + if hasattr(eid, "attributes") and eid.attributes.get("EIdType") == "doi": + doi = str(eid) + break + + journal = article_data.get("Journal", {}) + journal_title = journal.get("Title") + + return PaperMetadata( + id=f"pubmed_{pmid}", + title=str(title), + authors=authors, + year=year, + doi=doi, + abstract=abstract, + venue=str(journal_title) if journal_title else None, + citation_count=None, + is_open_access=None, + source=PaperSource.PUBMED, + url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/", + landing_url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/", + ) + + +class SemanticScholarSearcher: + """Semantic Scholar paper searcher (with rate limiting).""" + + def __init__(self) -> None: + config = get_config() + self.api_key = config.semantic_scholar_api_key + self.delay = config.semantic_scholar_delay_s + + def search(self, query: str, limit: int = 10) -> List[PaperMetadata]: + """ + Search Semantic Scholar for papers. + Rate limited to avoid 403 errors. + + Args: + query: Search query string + limit: Maximum number of results + + Returns: + List of PaperMetadata objects + """ + logger.info(f"Searching Semantic Scholar: '{query}' (limit={limit})") + + # Lazy import to avoid dependency issues + try: + from semanticscholar import SemanticScholar + except ImportError: + global _SEMANTIC_SCHOLAR_IMPORT_MISSING_LOGGED + if not _SEMANTIC_SCHOLAR_IMPORT_MISSING_LOGGED: + logger.debug("semanticscholar package not installed; Semantic Scholar source disabled.") + _SEMANTIC_SCHOLAR_IMPORT_MISSING_LOGGED = True + return [] + + time.sleep(self.delay) # Initial delay + + try: + client = SemanticScholar(api_key=self.api_key) + results = client.search_paper( + query, + limit=limit, + fields=['title', 'abstract', 'authors', 'year', 'externalIds', 'url', 'isOpenAccess', 'openAccessPdf', 'venue', 'citationCount'] + ) + + papers: List[PaperMetadata] = [] + for item in results: + if len(papers) >= limit: + break + + # Get PDF URL if available + pdf_url = None + if item.openAccessPdf and isinstance(item.openAccessPdf, dict): + pdf_url = item.openAccessPdf.get('url') + + paper = PaperMetadata( + id=f"s2_{item.paperId}", + title=item.title or "Unknown", + authors=[a.name for a in (item.authors or [])], + year=item.year, + doi=item.externalIds.get("DOI") if item.externalIds else None, + abstract=item.abstract, + venue=getattr(item, "venue", None), + citation_count=getattr(item, "citationCount", None), + is_open_access=bool(getattr(item, "isOpenAccess", False)), + source=PaperSource.SEMANTIC_SCHOLAR, + url=item.url, + landing_url=item.url, + pdf_url=pdf_url, + ) + papers.append(paper) + time.sleep(self.delay) # Rate limit between items + + logger.info(f"Semantic Scholar returned {len(papers)} papers") + return papers + + except Exception as e: + logger.warning(f"Semantic Scholar search failed (likely 403): {e}") + return [] + + +class PaperDiscoveryAgent: + """ + Paper discovery agent. + Aggregates multiple search sources, deduplicates, and sorts results. 
+ """ + + def __init__(self) -> None: + config = get_config() + self.searchers: List[tuple] = [] + + if config.arxiv_enabled: + self.searchers.append(("arxiv", ArxivSearcher())) + if config.pubmed_enabled: + self.searchers.append(("pubmed", PubMedSearcher())) + if config.semantic_scholar_enabled: + self.searchers.append(("semantic_scholar", SemanticScholarSearcher())) + + logger.info(f"Initialized PaperDiscoveryAgent with sources: {[s[0] for s in self.searchers]}") + + def discover( + self, + query: str, + limit_per_source: int = 10, + deduplicate: bool = True + ) -> List[PaperMetadata]: + """ + Search all sources and aggregate results. + + Args: + query: Search query + limit_per_source: Maximum results per source + deduplicate: Whether to deduplicate by title + + Returns: + Aggregated list of papers + """ + all_papers: List[PaperMetadata] = [] + + for source_name, searcher in self.searchers: + try: + papers = searcher.search(query, limit_per_source) + all_papers.extend(papers) + logger.info(f"{source_name} returned {len(papers)} papers") + except Exception as e: + logger.error(f"Search failed for {source_name}: {e}") + + logger.info(f"Total papers before deduplication: {len(all_papers)}") + + if deduplicate: + all_papers = self._deduplicate(all_papers) + logger.info(f"Total papers after deduplication: {len(all_papers)}") + + return all_papers + + def _deduplicate(self, papers: List[PaperMetadata]) -> List[PaperMetadata]: + """Deduplicate papers by normalized title.""" + seen_titles: set = set() + unique_papers: List[PaperMetadata] = [] + + for paper in papers: + # Normalize title for comparison + normalized = paper.title.lower().strip() + if normalized not in seen_titles: + seen_titles.add(normalized) + unique_papers.append(paper) + + return unique_papers + + def build_thermoelectric_query( + self, + polymer: Optional[str] = None, + include_tc: bool = True + ) -> str: + """ + Build a specialized thermoelectric search query. + + Args: + polymer: Specific polymer name (e.g., "P3HT") + include_tc: Whether to include thermal conductivity keywords + + Returns: + Optimized search query string + """ + base_terms = [ + "organic thermoelectric", + "conjugated polymer", + "electrical conductivity", + ] + + if include_tc: + base_terms.append("thermal conductivity") + + if polymer: + base_terms.insert(0, polymer) + + query = " ".join(base_terms) + logger.debug(f"Built query: {query}") + return query diff --git a/literature/evaluation.py b/literature/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..b686db473a7bb107d97ad512d62ccaa3563ee3e0 --- /dev/null +++ b/literature/evaluation.py @@ -0,0 +1,155 @@ +""" +Offline evaluation helpers for structured literature extraction. + +The harness is intentionally dataset-agnostic so POLYIE-formatted exports and +internal regression sets can share the same metric implementation. 
+""" +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence, Tuple + +from .property_registry import normalize_property_key + + +CORE_FIELDS = ["material_name", "property_key", "raw_value", "raw_unit", "method"] + + +def load_json_records(path: str | Path) -> List[Dict[str, Any]]: + fp = Path(path) + if fp.suffix == ".jsonl": + return [json.loads(line) for line in fp.read_text(encoding="utf-8").splitlines() if line.strip()] + data = json.loads(fp.read_text(encoding="utf-8")) + if isinstance(data, list): + return data + raise ValueError(f"Unsupported evaluation file format: {fp}") + + +def normalize_record(record: Dict[str, Any]) -> Dict[str, Any]: + material = str( + record.get("material_name") + or record.get("polymer_name") + or record.get("material") + or "" + ).strip() + property_key = normalize_property_key( + str(record.get("property_key") or record.get("property_name") or "") + ) or str(record.get("property_key") or record.get("property_name") or "").strip() + raw_value = str(record.get("raw_value") or record.get("value") or "").strip() + raw_unit = str(record.get("raw_unit") or record.get("unit") or "").strip() + method = str(record.get("method") or record.get("measurement_method") or "").strip() + evidence_quote = str(record.get("evidence_quote") or record.get("source_quote") or "").strip() + return { + "material_name": material, + "property_key": property_key, + "raw_value": raw_value, + "raw_unit": raw_unit, + "method": method, + "evidence_quote": evidence_quote, + } + + +def _safe_div(numerator: float, denominator: float) -> float: + return numerator / denominator if denominator else 0.0 + + +def _f1(precision: float, recall: float) -> float: + return (2 * precision * recall) / (precision + recall) if (precision + recall) else 0.0 + + +def _field_pairs(records: Sequence[Dict[str, Any]], field: str) -> set[Tuple[str, str]]: + pairs = set() + for record in records: + normalized = normalize_record(record) + key = normalized.get("material_name", "") + value = normalized.get(field, "") + if key and value: + pairs.add((key.lower(), value.lower())) + return pairs + + +def _relation_tuples(records: Sequence[Dict[str, Any]]) -> set[Tuple[str, str, str]]: + triples = set() + for record in records: + normalized = normalize_record(record) + if normalized["material_name"] and normalized["property_key"] and normalized["raw_value"]: + triples.add( + ( + normalized["material_name"].lower(), + normalized["property_key"].lower(), + normalized["raw_value"].lower(), + ) + ) + return triples + + +def _record_tuples(records: Sequence[Dict[str, Any]]) -> set[Tuple[str, str, str, str, str]]: + tuples = set() + for record in records: + normalized = normalize_record(record) + tuples.add( + tuple(normalized[field].lower() for field in CORE_FIELDS) + ) + return tuples + + +def evaluate_predictions( + gold_records: Sequence[Dict[str, Any]], + predicted_records: Sequence[Dict[str, Any]], +) -> Dict[str, Any]: + gold = [normalize_record(record) for record in gold_records] + predicted = [normalize_record(record) for record in predicted_records] + + field_metrics: Dict[str, Dict[str, float]] = {} + for field in CORE_FIELDS: + gold_pairs = _field_pairs(gold, field) + predicted_pairs = _field_pairs(predicted, field) + tp = len(gold_pairs & predicted_pairs) + precision = _safe_div(tp, len(predicted_pairs)) + recall = _safe_div(tp, len(gold_pairs)) + field_metrics[field] = { + "precision": precision, + "recall": recall, + "f1": 
_f1(precision, recall), + } + + gold_rel = _relation_tuples(gold) + pred_rel = _relation_tuples(predicted) + rel_tp = len(gold_rel & pred_rel) + rel_precision = _safe_div(rel_tp, len(pred_rel)) + rel_recall = _safe_div(rel_tp, len(gold_rel)) + + gold_records_set = _record_tuples(gold) + pred_records_set = _record_tuples(predicted) + record_tp = len(gold_records_set & pred_records_set) + record_precision = _safe_div(record_tp, len(pred_records_set)) + record_recall = _safe_div(record_tp, len(gold_records_set)) + + filled_fields = [ + sum(1 for field in CORE_FIELDS if record.get(field)) + for record in predicted + ] + record_completeness = _safe_div(sum(filled_fields), len(predicted) * len(CORE_FIELDS)) + source_grounding_hit_rate = _safe_div( + sum(1 for record in predicted if record.get("evidence_quote")), + len(predicted), + ) + + return { + "field_metrics": field_metrics, + "relation_level": { + "precision": rel_precision, + "recall": rel_recall, + "f1": _f1(rel_precision, rel_recall), + }, + "record_level": { + "precision": record_precision, + "recall": record_recall, + "f1": _f1(record_precision, record_recall), + }, + "record_completeness": record_completeness, + "source_grounding_hit_rate": source_grounding_hit_rate, + "gold_count": len(gold), + "predicted_count": len(predicted), + } diff --git a/literature/extraction.py b/literature/extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..ee657ed4a19df24cfc8508351e2877ed5c437cb2 --- /dev/null +++ b/literature/extraction.py @@ -0,0 +1,863 @@ +""" +LLM-based structured data extraction module. +Implements flexible interface: PageIndex (RAG via indexed PDFs) or Simple extraction (fallback). + +Prompts are dynamically built from user-selected target properties via +``literature.property_registry.build_extraction_prompt``. +""" +import json +import re +import logging +import os +from typing import List, Optional, Any +from datetime import datetime + +from .schemas import ( + PaperMetadata, + PolymerDataPoint, + ExtractionResult, + DataQuality +) +from .config import get_config +from .retrieval import extract_text_from_pdf +from .property_registry import PROPERTY_CATALOG, build_extraction_prompt, TEMPLATES + +logger = logging.getLogger(__name__) + +# Default property set used when no explicit target properties are provided. +# The legacy thermoelectric-only template no longer exists in the production +# registry, so fall back to the platform-wide property core. +_DEFAULT_PROPERTIES = TEMPLATES.get("platform_core") or list(PROPERTY_CATALOG.keys()) + +_SKIP_ERROR_MESSAGES = { + "llm_unconfigured", + "contextual_llm_unconfigured", + "extraction_backend_unconfigured", + "pageindex_requires_pdf_no_simple_backend", + "pageindex_sdk_unavailable_no_simple_backend", +} + + +def _normalize_base_url(url: Optional[str]) -> Optional[str]: + text = str(url or "").strip().rstrip("/") + return text or None + + +def _is_http_url(url: Optional[str]) -> bool: + text = _normalize_base_url(url) + return bool(text and (text.startswith("http://") or text.startswith("https://"))) + + +def is_expected_skip_error(error_message: Optional[str]) -> bool: + return str(error_message or "").strip() in _SKIP_ERROR_MESSAGES + + +# ============== JSON Safe Parsing (Fix Logic Bug #4 & #5) ============== + +def normalize_minus_signs(s: str) -> str: + """ + Normalize all types of minus signs to ASCII minus. + + Fixes Logic Bug #5: OCR may produce Unicode minus (U+2212) instead of ASCII. 
+ """ + minus_chars = [ + '−', # U+2212 MINUS SIGN + '–', # U+2013 EN DASH + '—', # U+2014 EM DASH + '‐', # U+2010 HYPHEN + '‑', # U+2011 NON-BREAKING HYPHEN + '‒', # U+2012 FIGURE DASH + '⁻', # U+207B SUPERSCRIPT MINUS + '₋', # U+208B SUBSCRIPT MINUS + ] + for char in minus_chars: + s = s.replace(char, '-') + return s + + +def safe_json_loads(text: str) -> Any: + """ + Safely parse JSON, handling common LLM output issues. + + Fixes Logic Bug #4: LLM may return NaN, Infinity, Python-style None, trailing commas. + """ + if not text: + return None + + text = text.strip() + + # Extract JSON from markdown code blocks + if "```json" in text: + text = text.split("```json")[1].split("```")[0] + elif "```" in text: + parts = text.split("```") + if len(parts) >= 2: + text = parts[1] + + # Normalize minus signs + text = normalize_minus_signs(text) + + # Fix Python-style -> JSON-style + text = re.sub(r'\bNone\b', 'null', text) + text = re.sub(r'\bTrue\b', 'true', text) + text = re.sub(r'\bFalse\b', 'false', text) + + # Remove trailing commas + text = re.sub(r',\s*}', '}', text) + text = re.sub(r',\s*]', ']', text) + + # Handle NaN and Infinity + text = re.sub(r'\bNaN\b', 'null', text) + text = re.sub(r'\bInfinity\b', 'null', text) + text = re.sub(r'-Infinity\b', 'null', text) + + try: + return json.loads(text) + except json.JSONDecodeError as e: + logger.warning(f"Initial JSON parse failed: {e}") + + # Try json_repair if available + try: + from json_repair import repair_json + repaired = repair_json(text) + return json.loads(repaired) + except ImportError: + logger.warning("json_repair not installed, cannot repair JSON") + raise + except Exception as e2: + logger.error(f"JSON repair also failed: {e2}") + raise + + + + +# Extraction prompt template +EXTRACTION_PROMPT = """ +You are an expert in organic thermoelectrics and polymer science. +Your task is to extract ALL experimental data points from the provided paper. + +## Target Data +Extract data for conjugated polymers used in thermoelectric applications, including: +- PEDOT:PSS, P3HT, PBTTT, P(NDI2OD-T2), PDPP series, etc. 
+ +## Required Fields (extract as many as available) +For EACH data point, extract: + +### Material Information +- polymer_name: The polymer name/abbreviation (e.g., "P3HT", "PEDOT:PSS") +- dopant: Dopant used (e.g., "DMSO", "H2SO4", "FeCl3") +- dopant_ratio: Dopant concentration if specified (e.g., "5 wt%", "1 M") + +### Processing Conditions +- solvent: Solvent used for film preparation +- concentration_mg_ml: Solution concentration in mg/mL +- spin_speed_rpm: Spin coating speed in RPM +- spin_time_s: Spin coating time in seconds +- annealing_temp_c: Annealing temperature in Celsius +- annealing_time_min: Annealing time in minutes +- annealing_atmosphere: Atmosphere during annealing (N2, Air, Vacuum) +- film_thickness_nm: Film thickness in nanometers + +### Electrical Properties +- electrical_conductivity_s_cm: Electrical conductivity in S/cm +- seebeck_coefficient_uv_k: Seebeck coefficient in μV/K +- power_factor_uw_m_k2: Power factor in μW/(m·K²) + +### Thermal Properties (IMPORTANT - often sparse) +- thermal_conductivity_w_mk: Thermal conductivity in W/(m·K) +- zt_figure_of_merit: ZT figure of merit (dimensionless) + +### Structural Characterization +- xrd_crystallinity_percent: Crystallinity percentage from XRD +- xrd_pi_stacking_angstrom: π-π stacking distance in Angstrom +- xrd_lamellar_spacing_angstrom: Lamellar spacing in Angstrom + +### Metadata +- source_table_or_figure: Where the data was found (e.g., "Table 1", "Figure 3") +- extraction_confidence: Your confidence in this extraction (0.0 to 1.0) + +## CRITICAL Rules +1. Extract ONLY experimentally measured values, not theoretical predictions +2. Convert all units to the specified standard units +3. If a value range is given (e.g., "100-200 S/cm"), use the AVERAGE +4. If a value is "not measured" or "N/A", use null +5. Each row in a table = one data point +6. Include the source_table_or_figure for traceability + +## Output Format +Return a valid JSON array. Example: +[ + { + "polymer_name": "PEDOT:PSS", + "dopant": "H2SO4", + "dopant_ratio": "5 vol%", + "electrical_conductivity_s_cm": 1200.5, + "thermal_conductivity_w_mk": 0.35, + "source_table_or_figure": "Table 2", + "extraction_confidence": 0.9 + } +] + +Return ONLY the JSON array, no markdown formatting, no explanations. +If no relevant data is found, return an empty array: [] +""" + + +class DataExtractor: + """ + Flexible data extractor with fallback strategy. 
+ + Primary: PageIndex (RAG via indexed PDFs) + Fallback: Simple extraction (pymupdf + direct LLM) + """ + + def __init__( + self, + strategy: Optional[str] = None, + target_properties: Optional[List[str]] = None, + extra_instructions: str = "", + ) -> None: + config = get_config() + self.strategy = strategy or config.extraction_strategy + self.llm_model = config.llm_model + self.gemini_key = config.gemini_api_key + self.openai_key = config.openai_api_key + self.openai_base_url = _normalize_base_url(config.openai_base_url) + self.pdf_dir = config.pdf_storage_dir + self.pageindex_api_key = config.pageindex_api_key + self.target_properties = target_properties or _DEFAULT_PROPERTIES + self.extra_instructions = extra_instructions + + logger.info(f"Initialized DataExtractor with strategy: {self.strategy}, properties: {self.target_properties}") + + def has_openai_backend(self) -> bool: + return _is_http_url(self.openai_base_url) + + def has_any_llm_backend(self) -> bool: + return self.has_openai_backend() or bool(str(self.gemini_key or "").strip()) + + def has_pageindex_backend(self) -> bool: + return bool(str(self.pageindex_api_key or "").strip()) + + def can_attempt_extraction(self) -> bool: + return self.has_pageindex_backend() or self.has_any_llm_backend() + + def availability_reason(self) -> Optional[str]: + if self.can_attempt_extraction(): + return None + return "Structured extraction skipped: configure PAGEINDEX_API_KEY or a valid LLM backend." + + def extract_from_papers( + self, + papers: List[PaperMetadata], + use_full_text: bool = True + ) -> List[ExtractionResult]: + """ + Extract data from multiple papers. + + Args: + papers: List of paper metadata (with pdf_path if available) + use_full_text: Use PDF full text if available + + Returns: + List of extraction results + """ + results: List[ExtractionResult] = [] + + for paper in papers: + try: + if self.strategy == "pageindex": + result = self._extract_with_pageindex(paper) + else: + result = self._extract_simple(paper, use_full_text) + results.append(result) + except Exception as e: + logger.error(f"Extraction failed for {paper.id}: {e}") + results.append(ExtractionResult( + paper=paper, + success=False, + error_message=str(e) + )) + + return results + + def _extract_simple( + self, + paper: PaperMetadata, + use_full_text: bool = True + ) -> ExtractionResult: + """ + Simple extraction: Extract PDF text -> Feed to LLM -> Parse JSON. + Often more effective for metadata extraction. 
+ """ + logger.info(f"Simple extraction for: {paper.title[:50]}...") + + # Get content + content = self._prepare_content(paper, use_full_text) + if not content: + return ExtractionResult( + paper=paper, + success=False, + error_message="No content available" + ) + + if not self.has_any_llm_backend(): + return ExtractionResult( + paper=paper, + success=False, + error_message="llm_unconfigured", + extraction_notes="Simple extraction skipped because no LLM backend is configured.", + ) + + # Call LLM with dynamic prompt + dynamic_prompt = build_extraction_prompt(self.target_properties, self.extra_instructions) + prompt = dynamic_prompt.replace("{title}", paper.title or "Unknown").replace("{content}", content) + + try: + raw_response = self._call_llm(prompt) + + if not raw_response: + return ExtractionResult( + paper=paper, + success=False, + error_message="LLM returned empty response" + ) + + # Parse response + data_points = self._parse_llm_output(raw_response, paper.id) + + # Assess quality for each point + for dp in data_points: + dp.quality_tier = self._assess_quality(dp) + + return ExtractionResult( + paper=paper, + data_points=data_points, + llm_model_used=self.llm_model, + extraction_timestamp=datetime.now(), + success=True + ) + + except Exception as e: + logger.error(f"Simple extraction failed: {e}") + return ExtractionResult( + paper=paper, + success=False, + error_message=str(e) + ) + + def _extract_with_pageindex(self, paper: PaperMetadata) -> ExtractionResult: + """ + PageIndex extraction (RAG-enhanced via indexed PDF). + Submits PDF to PageIndex, then uses chat_completions with extraction prompt. + Falls back to simple extraction if PageIndex is unavailable or fails. + """ + if not self.has_pageindex_backend(): + if self.has_any_llm_backend(): + return self._extract_simple(paper) + return ExtractionResult( + paper=paper, + success=False, + error_message="extraction_backend_unconfigured", + extraction_notes="No PageIndex or LLM backend is configured.", + ) + + if not paper.pdf_path or not os.path.exists(paper.pdf_path): + if self.has_any_llm_backend(): + return self._extract_simple(paper) + return ExtractionResult( + paper=paper, + success=False, + error_message="pageindex_requires_pdf_no_simple_backend", + extraction_notes="PageIndex extraction requires a PDF when no simple LLM fallback is available.", + ) + + try: + from src.literature_service.pageindex_client import PageIndexService + except ImportError: + if self.has_any_llm_backend(): + return self._extract_simple(paper) + return ExtractionResult( + paper=paper, + success=False, + error_message="pageindex_sdk_unavailable_no_simple_backend", + extraction_notes="PageIndex SDK unavailable and no simple LLM fallback is configured.", + ) + + logger.info(f"PageIndex extraction for: {paper.title[:50]}...") + + try: + service = PageIndexService(api_key=self.pageindex_api_key) + + # Submit the document to PageIndex + doc_id = service.submit_document(paper.pdf_path) + logger.info(f"Submitted to PageIndex, doc_id={doc_id}") + + # Wait for indexing to complete (poll status) + import time + for _ in range(30): # max ~60 seconds + status = service.get_document_status(doc_id) + if status == "completed": + break + if status in ("error", "failed"): + raise RuntimeError(f"PageIndex indexing failed with status: {status}") + time.sleep(2) + else: + logger.warning("PageIndex indexing timed out, falling back to simple") + return self._extract_simple(paper) + + # Use chat_completions with dynamic extraction prompt + dynamic_prompt = 
build_extraction_prompt(self.target_properties, self.extra_instructions) + # For PageIndex chat, we don't need the {title}/{content} placeholders + # since the document is already indexed; strip those sections. + pi_prompt = dynamic_prompt.split("**PAPER CONTENT:**")[0].strip() + raw_answer = service.chat_completions(pi_prompt, doc_id) + + if not raw_answer: + return ExtractionResult( + paper=paper, + success=False, + error_message="PageIndex returned empty response" + ) + + # Parse result + data_points = self._parse_llm_output(raw_answer, paper.id) + + for dp in data_points: + dp.quality_tier = self._assess_quality(dp) + + return ExtractionResult( + paper=paper, + data_points=data_points, + llm_model_used="pageindex", + extraction_timestamp=datetime.now(), + success=True + ) + + except Exception as e: + logger.warning(f"PageIndex extraction failed, falling back to simple: {e}") + return self._extract_simple(paper) + + def _prepare_content( + self, + paper: PaperMetadata, + use_full_text: bool = True + ) -> Optional[str]: + """Prepare text content for extraction.""" + # Try PDF full text first + if use_full_text and paper.pdf_path and os.path.exists(paper.pdf_path): + full_text = extract_text_from_pdf(paper.pdf_path, max_pages=5) + if full_text: + return f"Title: {paper.title}\n\n{full_text}" + + # Fallback to abstract + if paper.abstract: + return f"Title: {paper.title}\n\nAbstract:\n{paper.abstract}" + + # Just title + if paper.title: + return f"Title: {paper.title}" + + return None + + def _call_llm(self, prompt: str) -> Optional[str]: + """ + Call LLM (OpenAI-compatible first, then Gemini fallback). + Prioritizes CRC OpenWebUI for reliability. + """ + # Try OpenAI-compatible (CRC) first + if self.openai_key and self.openai_base_url: + try: + logger.info(f"Calling CRC OpenWebUI...") + return self._call_openai_compatible(prompt) + except Exception as e: + logger.warning(f"CRC OpenWebUI call failed: {e}") + + # Fallback to Gemini + if self.gemini_key: + try: + logger.info("Falling back to Gemini...") + return self._call_gemini(prompt) + except Exception as e: + logger.warning(f"Gemini call failed: {e}") + + logger.debug("No LLM backend configured; skipping simple extraction call.") + return None + + def _call_gemini(self, prompt: str) -> str: + """Call Gemini API.""" + import google.generativeai as genai + + genai.configure(api_key=self.gemini_key) + model = genai.GenerativeModel("gemini-2.0-flash") + + response = model.generate_content(prompt) + return response.text + + def _call_openai_compatible(self, prompt: str) -> str: + """Call OpenAI-compatible API (CRC OpenWebUI).""" + from openai import OpenAI + + client = OpenAI( + api_key=self.openai_key, + base_url=self.openai_base_url + ) + + # Use model from config (set in .env LLM_MODEL) + model = self.llm_model + # Handle litellm-style prefixes + if model.startswith("gemini/"): + model = "gpt-oss:latest" # Fallback for CRC + logger.info(f"Using model: {model}") + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1 + ) + + return response.choices[0].message.content + + def _parse_llm_output( + self, + raw_output: str, + paper_id: str + ) -> List[PolymerDataPoint]: + """Parse LLM output into structured data points.""" + try: + # Use safe_json_loads for robust parsing + raw_data = safe_json_loads(raw_output) + except Exception as e: + logger.error(f"JSON parsing failed for {paper_id}: {e}") + return [] + + if raw_data is None: + logger.warning(f"No JSON data found in 
output for {paper_id}") + return [] + + # Ensure it's a list + if not isinstance(raw_data, list): + raw_data = [raw_data] + + # Convert to Pydantic models + data_points: List[PolymerDataPoint] = [] + for item in raw_data: + try: + dp = PolymerDataPoint( + polymer_name=item.get("polymer_name", "Unknown"), + dopant=item.get("dopant"), + dopant_ratio=item.get("dopant_ratio"), + solvent=item.get("solvent"), + concentration_mg_ml=item.get("concentration_mg_ml"), + spin_speed_rpm=item.get("spin_speed_rpm"), + spin_time_s=item.get("spin_time_s"), + annealing_temp_c=item.get("annealing_temp_c"), + annealing_time_min=item.get("annealing_time_min"), + annealing_atmosphere=item.get("annealing_atmosphere"), + film_thickness_nm=item.get("film_thickness_nm"), + electrical_conductivity_s_cm=item.get("electrical_conductivity_s_cm"), + seebeck_coefficient_uv_k=item.get("seebeck_coefficient_uv_k"), + power_factor_uw_m_k2=item.get("power_factor_uw_m_k2"), + thermal_conductivity_w_mk=item.get("thermal_conductivity_w_mk"), + zt_figure_of_merit=item.get("zt_figure_of_merit"), + xrd_crystallinity_percent=item.get("xrd_crystallinity_percent"), + xrd_pi_stacking_angstrom=item.get("xrd_pi_stacking_angstrom"), + xrd_lamellar_spacing_angstrom=item.get("xrd_lamellar_spacing_angstrom"), + source_paper_id=paper_id, + source_table_or_figure=item.get("source_table_or_figure"), + extraction_confidence=item.get("extraction_confidence", 0.5), + ) + data_points.append(dp) + except Exception as e: + logger.warning(f"Failed to parse data point: {e}") + + logger.info(f"Extracted {len(data_points)} data points from {paper_id}") + return data_points + + def _assess_quality(self, dp: PolymerDataPoint) -> DataQuality: + """Assess data point quality tier.""" + has_ec = dp.electrical_conductivity_s_cm is not None + has_tc = dp.thermal_conductivity_w_mk is not None + has_xrd = (dp.xrd_crystallinity_percent is not None or + dp.xrd_pi_stacking_angstrom is not None) + has_process = (dp.annealing_temp_c is not None and + dp.spin_speed_rpm is not None) + + if has_ec and has_tc and has_xrd and has_process: + return DataQuality.GOLD + elif has_ec and (has_xrd or has_process): + return DataQuality.SILVER + else: + return DataQuality.BRONZE + + +# ============== NEW: Contextualized Extraction ============== + +CONTEXTUALIZED_EXTRACTION_PROMPT = """ +You are an expert in organic thermoelectrics and polymer science. +Extract ALL experimental data points from the provided paper. + +## CRITICAL REQUIREMENTS + +1. **Extract ALL values, not just the best one** + - A paper may report multiple values under different conditions + - Extract EACH value as a separate data point + +2. **Include COMPLETE experimental conditions** + - Every value must have its associated conditions + - Common: temperature, annealing, doping level, measurement method + +3. 
**MANDATORY: Include source quote** + - For EACH data point, include the exact sentence from the paper + - Quote must be >10 characters and reference the value + +## TARGET PROPERTIES + +- `electrical_conductivity` (S/cm, S/m) +- `thermal_conductivity` (W/mK) +- `seebeck_coefficient` (μV/K) +- `power_factor` (μW/mK²) +- `zt_figure_of_merit` (dimensionless) + +## OUTPUT FORMAT (JSON Array) + +Return ONLY valid JSON, no markdown, no explanation: + +[ + {{ + "polymer_name": "PEDOT:PSS", + "dopant": "H2SO4", + "dopant_ratio": "5 vol%", + "property_name": "electrical_conductivity", + "raw_value": "4380", + "raw_unit": "S/cm", + "conditions": {{ + "solvent": "water", + "annealing_temp_c": 150, + "annealing_time_min": 10, + "measurement_temp_k": 300, + "measurement_method": "4-point probe" + }}, + "source_quote": "The electrical conductivity reached 4380 S/cm after H2SO4 treatment.", + "source_location": "Table 2, Sample S5", + "extraction_confidence": 0.95 + }} +] + +## RULES + +1. If values range "from X to Y", extract BOTH as separate points +2. Preserve scientific notation as "5.2e3" or actual number +3. If no source quote found, set extraction_confidence < 0.5 +4. Return ONLY valid JSON array, no other text + +--- + +**PAPER CONTENT:** + +Title: {title} + +{content} + +--- + +JSON output: +""" + + +class ContextualizedExtractor: + """ + Contextualized data extractor. + + Produces ContextualizedValue objects with mandatory source quotes for traceability. + """ + + def __init__( + self, + model_id: str = None, + target_properties: Optional[List[str]] = None, + extra_instructions: str = "", + ): + """ + Initialize extractor. + + Args: + model_id: LLM model ID to use (default from config) + target_properties: List of property keys to extract + extra_instructions: Free-form LLM instructions + """ + config = get_config() + self.model_id = model_id or config.llm_model + self.openai_base_url = _normalize_base_url(config.openai_base_url) + self.openai_key = config.openai_api_key + self.target_properties = target_properties or _DEFAULT_PROPERTIES + self.extra_instructions = extra_instructions + + def is_configured(self) -> bool: + return _is_http_url(self.openai_base_url) + + def extract_from_paper( + self, + paper: PaperMetadata, + use_full_text: bool = True + ) -> "ExtractionResult": + """ + Extract contextualized data from a paper. 
+ + Args: + paper: Paper metadata + use_full_text: Use PDF full text if available + + Returns: + ExtractionResult with ContextualizedValue data points + """ + from .schemas import ContextualizedValue, ExperimentalConditions, ExtractionResult + + logger.info(f"Contextualized extraction for: {paper.title[:50]}...") + + if not self.is_configured(): + return ExtractionResult( + paper_id=paper.id, + paper_title=paper.title, + success=False, + error_message="contextual_llm_unconfigured", + extraction_notes="Contextualized extraction skipped because no OpenAI-compatible base URL is configured.", + ) + + # Prepare content + content = paper.full_text if use_full_text and paper.full_text else paper.abstract + if not content: + return ExtractionResult( + paper_id=paper.id, + paper_title=paper.title, + success=False, + error_message="No content available" + ) + + # Truncate content to fit context window + content = content[:15000] + + # Build dynamic prompt from target properties + prompt_template = build_extraction_prompt(self.target_properties, self.extra_instructions) + prompt = prompt_template.replace("{title}", paper.title or "Unknown").replace("{content}", content) + + try: + # Call LLM + raw_response = self._call_llm(prompt) + + if not raw_response: + return ExtractionResult( + paper_id=paper.id, + paper_title=paper.title, + success=False, + error_message="contextual_llm_unconfigured" if not self.is_configured() else "LLM returned empty response" + ) + + # Parse response + data_points = self._parse_response(raw_response, paper.id) + + return ExtractionResult( + paper_id=paper.id, + paper_title=paper.title, + data_points=data_points, + extraction_model=self.model_id, + success=True + ) + + except Exception as e: + logger.warning(f"Contextualized extraction failed for {paper.id}: {e}") + return ExtractionResult( + paper_id=paper.id, + paper_title=paper.title, + success=False, + error_message=str(e) + ) + + def _call_llm(self, prompt: str) -> Optional[str]: + """Call LLM via OpenAI-compatible API.""" + import httpx + + if not self.is_configured(): + logger.debug("Contextualized extractor skipped: OpenAI-compatible base URL is not configured.") + return None + + logger.info("Calling LLM for contextualized extraction...") + logger.info(f"Using model: {self.model_id}") + + headers = { + "Content-Type": "application/json", + } + if self.openai_key: + headers["Authorization"] = f"Bearer {self.openai_key}" + + payload = { + "model": self.model_id, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.2, + "max_tokens": 3000, + } + + with httpx.Client(timeout=120) as client: + response = client.post( + f"{self.openai_base_url}/chat/completions", + json=payload, + headers=headers + ) + response.raise_for_status() + data = response.json() + + return data["choices"][0]["message"]["content"] + + def _parse_response(self, response: str, paper_id: str) -> List: + """Parse LLM response into ContextualizedValue objects.""" + from .schemas import ContextualizedValue, ExperimentalConditions + + try: + data = safe_json_loads(response) + except Exception as e: + logger.warning(f"JSON parse failed for {paper_id}: {e}") + return [] + + if data is None: + return [] + + if not isinstance(data, list): + data = [data] + + results = [] + for item in data: + if not isinstance(item, dict): + continue + + try: + # Handle conditions + conditions_data = item.pop("conditions", {}) + conditions = ExperimentalConditions(**conditions_data) if conditions_data else ExperimentalConditions() + + # Ensure required 
fields + if "source_quote" not in item or not item.get("source_quote"): + item["source_quote"] = f"[Extracted from {paper_id}]" + + value = ContextualizedValue( + conditions=conditions, + **item + ) + results.append(value) + except Exception as e: + logger.warning(f"Failed to parse data point: {e}") + continue + + logger.info(f"Extracted {len(results)} contextualized data points from {paper_id}") + return results + + def extract_from_papers( + self, + papers: List[PaperMetadata], + use_full_text: bool = True + ) -> List: + """Batch extraction from multiple papers.""" + results = [] + for paper in papers: + result = self.extract_from_paper(paper, use_full_text) + results.append(result) + return results diff --git a/literature/graph.py b/literature/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..1a820e1d74c1788d4b4e906c5cc87a76759b831f --- /dev/null +++ b/literature/graph.py @@ -0,0 +1,450 @@ +""" +LangGraph workflow for Literature Discovery System. +Implements: discover → download → extract → quality pipeline. + +Key design principles: +1. All state modifications must be explicit in return values +2. No in-place object modification +3. Each node returns logs for UI feedback +""" +import logging +from typing import TypedDict, List, Optional, Annotated, Literal, Callable, Any +from datetime import datetime +import operator + +from langgraph.graph import StateGraph, END, START +from langgraph.checkpoint.memory import MemorySaver + +from .schemas import PaperMetadata, PolymerDataPoint, ExtractionResult, DataQuality +from .discovery import PaperDiscoveryAgent +from .extraction import DataExtractor +from .quality import QualityAssessor + +logger = logging.getLogger(__name__) + + +# ============== State Definition ============== + +class LogEntry(TypedDict): + """Log entry for UI feedback""" + timestamp: str + node: str + message: str + level: str # info, warning, error + + +class LiteratureState(TypedDict): + """ + Workflow state. + + Important: LangGraph state updates are based on return values. + If you modify a field, you MUST include it in the return dict. 
+ """ + # Input + search_query: str + max_papers: int + use_full_text: bool + + # Progress tracking + current_node: str + progress_percent: int + + # Intermediate results + papers: List[Any] # List[PaperMetadata] serialized + downloaded_pdfs: List[str] + extraction_results: List[Any] # List[ExtractionResult] serialized + + # Final output + verified_data: List[Any] # List[PolymerDataPoint] serialized + quality_report: Optional[dict] + + # Logging & Status + logs: Annotated[List[LogEntry], operator.add] + status: Literal["running", "completed", "failed", "cancelled"] + error: Optional[str] + + +def create_initial_state( + query: str, + max_papers: int = 10, + use_full_text: bool = False +) -> LiteratureState: + """Create initial state""" + return { + "search_query": query, + "max_papers": max_papers, + "use_full_text": use_full_text, + "current_node": "start", + "progress_percent": 0, + "papers": [], + "downloaded_pdfs": [], + "extraction_results": [], + "verified_data": [], + "quality_report": None, + "logs": [], + "status": "running", + "error": None, + } + + +# ============== Helper Functions ============== + +def _log(node: str, message: str, level: str = "info") -> LogEntry: + """Create log entry""" + return { + "timestamp": datetime.now().isoformat(), + "node": node, + "message": message, + "level": level, + } + + +def _serialize_paper(paper: PaperMetadata) -> dict: + """Serialize paper for state storage""" + return paper.model_dump() + + +def _deserialize_paper(data: dict) -> PaperMetadata: + """Deserialize paper from state""" + return PaperMetadata(**data) + + +# ============== Node Functions ============== + +def discover_node(state: LiteratureState) -> dict: + """ + Paper discovery node. + Uses existing PaperDiscoveryAgent (synchronous). + """ + node_name = "discover" + logs = [_log(node_name, f"Searching for: '{state['search_query']}'")] + + try: + agent = PaperDiscoveryAgent() + papers = agent.discover( + query=state["search_query"], + limit_per_source=state["max_papers"], + ) + + logs.append(_log(node_name, f"Found {len(papers)} unique papers")) + + # Serialize papers for state storage + serialized_papers = [_serialize_paper(p) for p in papers] + + return { + "papers": serialized_papers, + "current_node": node_name, + "progress_percent": 25, + "logs": logs, + } + except Exception as e: + logger.exception(f"Discover node failed: {e}") + logs.append(_log(node_name, f"Error: {e}", "error")) + return { + "papers": [], + "current_node": node_name, + "status": "failed", + "error": str(e), + "logs": logs, + } + + +def download_node(state: LiteratureState) -> dict: + """ + PDF download node. + Uses existing PDFRetriever (synchronous). 
+ """ + from .retrieval import PDFRetriever + + node_name = "download" + paper_dicts = state["papers"] + logs = [_log(node_name, f"Downloading content for {len(paper_dicts)} papers")] + + try: + # Deserialize papers + papers = [_deserialize_paper(p) for p in paper_dicts] + + retriever = PDFRetriever() + papers = retriever.retrieve_batch(papers) + + # Count successes + downloaded = [p for p in papers if p.pdf_path] + logs.append(_log(node_name, f"Downloaded {len(downloaded)}/{len(papers)} PDFs")) + + # Re-serialize papers with updated pdf_path + serialized_papers = [_serialize_paper(p) for p in papers] + downloaded_pdfs = [p.pdf_path for p in downloaded if p.pdf_path] + + return { + "papers": serialized_papers, + "downloaded_pdfs": downloaded_pdfs, + "current_node": node_name, + "progress_percent": 50, + "logs": logs, + } + except Exception as e: + logger.exception(f"Download node failed: {e}") + logs.append(_log(node_name, f"Error: {e}", "error")) + return { + "downloaded_pdfs": [], + "current_node": node_name, + "status": "failed", + "error": str(e), + "logs": logs, + } + + +def extract_node(state: LiteratureState) -> dict: + """ + Data extraction node. + Uses existing DataExtractor (synchronous). + """ + node_name = "extract" + logs = [_log(node_name, "Extracting structured data from papers")] + + try: + # Deserialize papers + papers = [_deserialize_paper(p) for p in state["papers"]] + + # Filter papers with content + papers_with_content = [p for p in papers if p.pdf_path or p.abstract] + + if not papers_with_content: + logs.append(_log(node_name, "No papers with content to extract", "warning")) + return { + "extraction_results": [], + "current_node": node_name, + "progress_percent": 75, + "logs": logs, + } + + logs.append(_log(node_name, f"Processing {len(papers_with_content)} papers with content")) + + extractor = DataExtractor() + results = extractor.extract_from_papers( + papers_with_content, + use_full_text=state["use_full_text"] + ) + + total_points = sum(len(r.data_points) for r in results if r.success) + logs.append(_log(node_name, f"Extracted {total_points} data points from {len(results)} papers")) + + # Serialize results + serialized_results = [] + for r in results: + serialized_results.append({ + "paper_id": r.paper.id if r.paper else "unknown", + "success": r.success, + "error_message": r.error_message, + "data_points": [dp.model_dump() for dp in r.data_points] if r.data_points else [], + }) + + return { + "extraction_results": serialized_results, + "current_node": node_name, + "progress_percent": 75, + "logs": logs, + } + except Exception as e: + logger.exception(f"Extract node failed: {e}") + logs.append(_log(node_name, f"Error: {e}", "error")) + return { + "extraction_results": [], + "current_node": node_name, + "status": "failed", + "error": str(e), + "logs": logs, + } + + +def quality_node(state: LiteratureState) -> dict: + """ + Quality assessment node. 
+ """ + node_name = "quality" + logs = [_log(node_name, "Assessing data quality")] + + try: + # Collect all data points from serialized results + all_points: List[PolymerDataPoint] = [] + for result_dict in state["extraction_results"]: + if result_dict.get("success") and result_dict.get("data_points"): + for dp_dict in result_dict["data_points"]: + try: + dp = PolymerDataPoint(**dp_dict) + all_points.append(dp) + except Exception as e: + logger.warning(f"Failed to deserialize data point: {e}") + + if not all_points: + logs.append(_log(node_name, "No data points to assess", "warning")) + return { + "verified_data": [], + "quality_report": None, + "current_node": node_name, + "progress_percent": 100, + "status": "completed", + "logs": logs, + } + + assessor = QualityAssessor() + verified, report = assessor.assess_batch(all_points) + + logs.append(_log(node_name, report.summary())) + + # Serialize + report_dict = { + "total_points": report.total_points, + "gold_count": report.gold_count, + "silver_count": report.silver_count, + "bronze_count": report.bronze_count, + "invalid_count": report.invalid_count, + "validation_errors": report.validation_errors, + } + + verified_data = [dp.model_dump() for dp in verified] + + return { + "verified_data": verified_data, + "quality_report": report_dict, + "current_node": node_name, + "progress_percent": 100, + "status": "completed", + "logs": logs, + } + except Exception as e: + logger.exception(f"Quality node failed: {e}") + logs.append(_log(node_name, f"Error: {e}", "error")) + return { + "verified_data": [], + "quality_report": None, + "current_node": node_name, + "status": "failed", + "error": str(e), + "logs": logs, + } + + +# ============== Conditional Edges ============== + +def should_continue_after_discover(state: LiteratureState) -> str: + """Should continue after discovery?""" + if state.get("status") == "failed": + return "end" + if not state.get("papers"): + return "end" + return "download" + + +def should_continue_after_download(state: LiteratureState) -> str: + """Should continue after download?""" + if state.get("status") == "failed": + return "end" + if not state.get("downloaded_pdfs") and not state.get("papers"): + return "end" + return "extract" + + +def should_continue_after_extract(state: LiteratureState) -> str: + """Should continue after extraction?""" + if state.get("status") == "failed": + return "end" + + # Check if any extraction succeeded + results = state.get("extraction_results", []) + total_points = sum( + len(r.get("data_points", [])) + for r in results + if r.get("success") + ) + if total_points == 0: + return "end" + return "quality" + + +# ============== Graph Builder ============== + +def create_literature_graph(checkpointer=None): + """ + Create the literature mining workflow graph. 
+ + Args: + checkpointer: Optional checkpoint storage (defaults to MemorySaver) + + Returns: + Compiled LangGraph + """ + builder = StateGraph(LiteratureState) + + # Add nodes + builder.add_node("discover", discover_node) + builder.add_node("download", download_node) + builder.add_node("extract", extract_node) + builder.add_node("quality", quality_node) + + # Add edges + builder.add_edge(START, "discover") + + builder.add_conditional_edges( + "discover", + should_continue_after_discover, + {"download": "download", "end": END} + ) + + builder.add_conditional_edges( + "download", + should_continue_after_download, + {"extract": "extract", "end": END} + ) + + builder.add_conditional_edges( + "extract", + should_continue_after_extract, + {"quality": "quality", "end": END} + ) + + builder.add_edge("quality", END) + + # Compile + if checkpointer is None: + checkpointer = MemorySaver() + + graph = builder.compile(checkpointer=checkpointer) + + return graph + + +# ============== Sync Runner ============== + +def run_workflow( + query: str, + max_papers: int = 10, + use_full_text: bool = False, + thread_id: str = "default", + on_state_update: Optional[Callable[[LiteratureState], None]] = None, +) -> LiteratureState: + """ + Run the literature mining workflow (synchronous). + + Args: + query: Search query + max_papers: Max papers per source + use_full_text: Whether to use full text extraction + thread_id: Thread ID for state recovery + on_state_update: Callback for state updates + + Returns: + Final state + """ + graph = create_literature_graph() + initial_state = create_initial_state(query, max_papers, use_full_text) + + config = {"configurable": {"thread_id": thread_id}} + + final_state = None + for event in graph.stream(initial_state, config, stream_mode="values"): + final_state = event + if on_state_update: + on_state_update(event) + + return final_state diff --git a/literature/property_registry.py b/literature/property_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..59e96ec9471b55f52fd9bdca4e0e0a70124ecdb2 --- /dev/null +++ b/literature/property_registry.py @@ -0,0 +1,274 @@ +""" +Property catalog and extraction prompt builder for production literature mining. + +This registry is aligned to the platform's public property keys so staged +literature evidence can be consumed by Property Probe and Discovery without +ad-hoc remapping. 
+""" +from __future__ import annotations + +import re +from typing import Dict, List, Optional + + +PROPERTY_CATALOG: Dict[str, Dict[str, str]] = { + # Thermal + "tm": {"name": "Melting temperature", "unit": "K"}, + "tg": {"name": "Glass transition temperature", "unit": "K"}, + "td": {"name": "Thermal diffusivity", "unit": "m^2/s"}, + "tc": {"name": "Thermal conductivity", "unit": "W/(m*K)"}, + "cp": {"name": "Specific heat capacity", "unit": "J/(kg*K)"}, + # Mechanical + "young": {"name": "Young's modulus", "unit": "GPa"}, + "shear": {"name": "Shear modulus", "unit": "GPa"}, + "bulk": {"name": "Bulk modulus", "unit": "GPa"}, + "poisson": {"name": "Poisson ratio", "unit": "dimensionless"}, + # Transport + "visc": {"name": "Viscosity", "unit": "Pa*s"}, + "dif": {"name": "Diffusivity", "unit": "cm^2/s"}, + # Gas permeability + "phe": {"name": "He permeability", "unit": "Barrer"}, + "ph2": {"name": "H2 permeability", "unit": "Barrer"}, + "pco2": {"name": "CO2 permeability", "unit": "Barrer"}, + "pn2": {"name": "N2 permeability", "unit": "Barrer"}, + "po2": {"name": "O2 permeability", "unit": "Barrer"}, + "pch4": {"name": "CH4 permeability", "unit": "Barrer"}, + # Electronic / optical + "alpha": {"name": "Polarizability", "unit": "a.u."}, + "homo": {"name": "HOMO energy", "unit": "eV"}, + "lumo": {"name": "LUMO energy", "unit": "eV"}, + "bandgap": {"name": "Band gap", "unit": "eV"}, + "mu": {"name": "Dipole moment", "unit": "Debye"}, + "etotal": {"name": "Total electronic energy", "unit": "eV"}, + "ri": {"name": "Refractive index", "unit": "dimensionless"}, + "dc": {"name": "Dielectric constant", "unit": "dimensionless"}, + "pe": {"name": "Permittivity", "unit": "dimensionless"}, + # Structural / physical + "rg": {"name": "Radius of gyration", "unit": "Angstrom"}, + "rho": {"name": "Density", "unit": "g/cm^3"}, + # Extended literature-only properties retained for discovery/search + "electrical_conductivity": {"name": "Electrical conductivity", "unit": "S/cm"}, + "seebeck_coefficient": {"name": "Seebeck coefficient", "unit": "uV/K"}, + "power_factor": {"name": "Power factor", "unit": "uW/(m*K^2)"}, + "zt_figure_of_merit": {"name": "ZT figure of merit", "unit": "dimensionless"}, + "tensile_strength": {"name": "Tensile strength", "unit": "MPa"}, + "elongation_at_break": {"name": "Elongation at break", "unit": "%"}, + "crystallinity": {"name": "Crystallinity", "unit": "%"}, +} + + +PLATFORM_PROPERTY_KEYS = [ + "tm", "tg", "td", "tc", "cp", + "young", "shear", "bulk", "poisson", + "visc", "dif", + "phe", "ph2", "pco2", "pn2", "po2", "pch4", + "alpha", "homo", "lumo", "bandgap", "mu", "etotal", "ri", "dc", "pe", + "rg", "rho", +] + + +TEMPLATES: Dict[str, List[str]] = { + "thermal": ["tm", "tg", "td", "tc", "cp"], + "mechanical": ["young", "shear", "bulk", "poisson", "tensile_strength", "elongation_at_break"], + "electronic": ["bandgap", "homo", "lumo", "ri", "dc", "pe", "alpha", "mu", "etotal"], + "gas_permeability": ["pco2", "po2", "pn2", "ph2", "phe", "pch4"], + "transport": ["visc", "dif", "tc", "electrical_conductivity", "seebeck_coefficient", "power_factor"], + "platform_core": PLATFORM_PROPERTY_KEYS, +} + +TEMPLATE_LABELS: Dict[str, str] = { + "thermal": "Thermal", + "mechanical": "Mechanical", + "electronic": "Electronic / Optical", + "gas_permeability": "Gas Permeability", + "transport": "Transport / Energy", + "platform_core": "Platform Core", +} + + +PROPERTY_ALIASES: Dict[str, str] = { + "thermal conductivity": "tc", + "heat conductivity": "tc", + "thermal diffusivity": "td", + "heat 
diffusivity": "td", + "specific heat": "cp", + "heat capacity": "cp", + "young modulus": "young", + "youngs modulus": "young", + "young_s_modulus": "young", + "young_modulus": "young", + "shear modulus": "shear", + "shear_modulus": "shear", + "bulk modulus": "bulk", + "bulk_modulus": "bulk", + "poisson ratio": "poisson", + "poisson_ratio": "poisson", + "viscosity": "visc", + "diffusivity": "dif", + "he permeability": "phe", + "helium permeability": "phe", + "h2 permeability": "ph2", + "co2 permeability": "pco2", + "n2 permeability": "pn2", + "o2 permeability": "po2", + "ch4 permeability": "pch4", + "polarizability": "alpha", + "homo energy": "homo", + "lumo energy": "lumo", + "band gap": "bandgap", + "bandgap": "bandgap", + "dipole moment": "mu", + "total electronic energy": "etotal", + "refractive index": "ri", + "dielectric constant": "dc", + "permittivity": "pe", + "radius of gyration": "rg", + "density": "rho", + "electrical conductivity": "electrical_conductivity", + "conductivity": "electrical_conductivity", + "seebeck coefficient": "seebeck_coefficient", + "power factor": "power_factor", + "zt": "zt_figure_of_merit", + "zt figure of merit": "zt_figure_of_merit", + "tensile strength": "tensile_strength", + "elongation at break": "elongation_at_break", + "co2_permeability": "pco2", + "o2_permeability": "po2", + "n2_permeability": "pn2", + "h2_permeability": "ph2", + "he_permeability": "phe", + "ch4_permeability": "pch4", + "radius_of_gyration": "rg", + "refractive_index": "ri", + "dielectric_constant": "dc", + "dipole_moment": "mu", +} + + +def _norm(text: str) -> str: + normalized = re.sub(r"[^a-z0-9]+", " ", str(text or "").strip().lower()) + return re.sub(r"\s+", " ", normalized).strip() + + +for key, meta in PROPERTY_CATALOG.items(): + PROPERTY_ALIASES.setdefault(_norm(key), key) + PROPERTY_ALIASES.setdefault(_norm(meta["name"]), key) + + +def normalize_property_key(value: str | None) -> Optional[str]: + """Map free-form property text to a canonical registry key.""" + if not value: + return None + key = PROPERTY_ALIASES.get(_norm(value)) + if key in PROPERTY_CATALOG: + return key + return None + + +def detect_property_keys(text: str) -> List[str]: + """Return all unique property keys that appear in the free-form text.""" + haystack = _norm(text) + out: List[str] = [] + for alias, key in PROPERTY_ALIASES.items(): + if alias and alias in haystack and key not in out: + out.append(key) + return out + + +def property_display_name(key: str) -> str: + meta = PROPERTY_CATALOG.get(key) + if not meta: + return key + return f"{meta['name']} ({meta['unit']})" + + +def _property_list_block(property_keys: List[str]) -> str: + """Build the target-properties section of the extraction prompt.""" + lines = [] + for key in property_keys: + meta = PROPERTY_CATALOG.get(key) + if meta: + lines.append(f"- `{key}` ({meta['name']}) -- standard unit: {meta['unit']}") + else: + lines.append(f"- `{key}`") + return "\n".join(lines) + + +def build_extraction_prompt( + property_keys: List[str], + extra_instructions: str = "", +) -> str: + """ + Build a dynamic contextualized extraction prompt from the given property list. + """ + normalized_keys = [normalize_property_key(k) or k for k in property_keys if k] + props_block = _property_list_block(normalized_keys) + + extra_section = "" + if extra_instructions.strip(): + extra_section = f""" +## ADDITIONAL CONTEXT + +{extra_instructions.strip()} +""" + + prompt = f"""You are an expert in polymer science and materials characterization. 
+Extract experimentally grounded evidence records from the provided paper. + +## CRITICAL REQUIREMENTS + +1. Extract each material-property-value observation as a separate record +2. Preserve the original value and unit exactly as written +3. Include experimental conditions and measurement method whenever available +4. Include a source quote and source location for every record +5. Ignore theoretical-only values unless the paper explicitly reports an experiment-backed measurement + +## TARGET PROPERTIES + +For each data point, extract these properties: +{props_block} +{extra_section} +## OUTPUT FORMAT (JSON Array) + +Return ONLY valid JSON, no markdown, no explanation: + +[ + {{ + "polymer_name": "P3HT", + "property_name": "", + "raw_value": "1.9", + "raw_unit": "eV", + "conditions": {{ + "solvent": "chloroform", + "annealing_temp_c": 150, + "annealing_time_min": 10, + "measurement_temp_k": 300, + "measurement_method": "UV-Vis" + }}, + "source_quote": "The optical band gap of P3HT was determined to be 1.9 eV from the UV-Vis absorption onset.", + "source_location": "Table 1", + "extraction_confidence": 0.95 + }} +] + +## RULES + +1. If values range "from X to Y", extract BOTH as separate points +2. Preserve scientific notation as "5.2e3" or actual number +3. If no source quote is available, lower extraction_confidence below 0.5 +4. Prefer experimentally measured values over model predictions or simulations +5. Return ONLY a valid JSON array, no extra text + +--- + +**PAPER CONTENT:** + +Title: {{title}} + +{{content}} + +--- + +JSON output: +""" + return prompt diff --git a/literature/quality.py b/literature/quality.py new file mode 100644 index 0000000000000000000000000000000000000000..daab6660b18793740424b7f8b134ae14bc45d1db --- /dev/null +++ b/literature/quality.py @@ -0,0 +1,176 @@ +""" +Production quality assessment and validation for literature evidence. 
+""" +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +from .schemas import ContextualizedValue, DataQuality, PolymerDataPoint + +logger = logging.getLogger(__name__) + + +@dataclass +class QualityReport: + """Data quality report for a batch of data points.""" + total_points: int + gold_count: int + silver_count: int + bronze_count: int + invalid_count: int + validation_errors: List[str] + + @property + def gold_ratio(self) -> float: + return self.gold_count / max(self.total_points, 1) + + def summary(self) -> str: + return ( + f"Quality Report: {self.total_points} points\n" + f" Gold: {self.gold_count} ({self.gold_ratio:.1%})\n" + f" Silver: {self.silver_count}\n" + f" Bronze: {self.bronze_count}\n" + f" Invalid: {self.invalid_count}\n" + f" Errors: {len(self.validation_errors)}" + ) + + +class QualityAssessor: + """Quality assessor with property-aware sanity checks.""" + + PROPERTY_BOUNDS: Dict[str, Tuple[Optional[float], Optional[float]]] = { + "tm": (50, 2000), + "tg": (50, 2000), + "td": (1e-10, 1.0), + "tc": (1e-4, 1000.0), + "cp": (1.0, 1e7), + "young": (1e-6, 1e5), + "shear": (1e-6, 1e5), + "bulk": (1e-6, 1e5), + "poisson": (-1.0, 0.5), + "visc": (1e-9, 1e9), + "dif": (1e-12, 10.0), + "rho": (1e-6, 100.0), + "ri": (0.5, 10.0), + "bandgap": (-20.0, 20.0), + "homo": (-30.0, 10.0), + "lumo": (-30.0, 20.0), + "mu": (0.0, 1e4), + "electrical_conductivity": (1e-12, 1e8), + "seebeck_coefficient": (-1e5, 1e5), + "power_factor": (0.0, 1e9), + "zt_figure_of_merit": (0.0, 1e4), + } + + def __init__(self) -> None: + self.errors: List[str] = [] + + def assess_batch(self, data_points: List[PolymerDataPoint]) -> Tuple[List[PolymerDataPoint], QualityReport]: + """Legacy compatibility path used by older scripts.""" + self.errors = [] + valid_points: List[PolymerDataPoint] = [] + gold_count = silver_count = bronze_count = invalid_count = 0 + + for dp in data_points: + is_valid, error_msg = self._validate_legacy(dp) + if not is_valid: + self.errors.append(f"{dp.source_paper_id}: {error_msg}") + invalid_count += 1 + continue + + dp.quality_tier = self._compute_legacy_quality_tier(dp) + if dp.quality_tier == DataQuality.GOLD: + gold_count += 1 + elif dp.quality_tier == DataQuality.SILVER: + silver_count += 1 + else: + bronze_count += 1 + valid_points.append(dp) + + report = QualityReport( + total_points=len(data_points), + gold_count=gold_count, + silver_count=silver_count, + bronze_count=bronze_count, + invalid_count=invalid_count, + validation_errors=self.errors.copy(), + ) + logger.info(report.summary()) + return valid_points, report + + def validate_contextual_value(self, value: ContextualizedValue) -> Tuple[bool, Optional[str]]: + if not value.polymer_name or value.polymer_name.strip().lower() == "unknown": + return False, "Missing material name" + if not value.property_name: + return False, "Missing property key" + if value.standardized_value is None: + return False, "Missing standardized value" + if not value.source_quote or len(value.source_quote.strip()) < 10: + return False, "Missing source quote" + + bounds = self.PROPERTY_BOUNDS.get(value.property_name) + if bounds is None: + return True, None + + low, high = bounds + numeric = value.standardized_value + if low is not None and numeric < low: + return False, f"Value below plausible range: {numeric}" + if high is not None and numeric > high: + return False, f"Value above plausible range: {numeric}" + return True, None + + def 
assess_contextual_quality(self, value: ContextualizedValue) -> DataQuality: + score = 0 + if value.standardized_value is not None: + score += 2 + if value.conditions.to_dict(): + score += min(len(value.conditions.to_dict()), 3) + if value.conditions.measurement_method: + score += 1 + if value.source_location: + score += 1 + if value.extraction_confidence >= 0.9: + score += 2 + elif value.extraction_confidence >= 0.7: + score += 1 + + if score >= 7: + return DataQuality.GOLD + if score >= 4: + return DataQuality.SILVER + return DataQuality.BRONZE + + def _validate_legacy(self, dp: PolymerDataPoint) -> Tuple[bool, Optional[str]]: + if not dp.polymer_name or dp.polymer_name == "Unknown": + return False, "Missing polymer name" + has_measurement = any([ + dp.electrical_conductivity_s_cm is not None, + dp.thermal_conductivity_w_mk is not None, + dp.seebeck_coefficient_uv_k is not None, + ]) + if not has_measurement: + return False, "No measurement values" + return True, None + + def _compute_legacy_quality_tier(self, dp: PolymerDataPoint) -> DataQuality: + score = 0 + if dp.electrical_conductivity_s_cm is not None: + score += 3 + if dp.seebeck_coefficient_uv_k is not None: + score += 2 + if dp.power_factor_uw_m_k2 is not None: + score += 1 + if dp.thermal_conductivity_w_mk is not None: + score += 4 + if dp.source_table_or_figure: + score += 1 + if dp.annealing_temp_c is not None: + score += 1 + if score >= 7: + return DataQuality.GOLD + if score >= 4: + return DataQuality.SILVER + return DataQuality.BRONZE diff --git a/literature/retrieval.py b/literature/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..2bba8f742360943db843df55185b729dcc222f54 --- /dev/null +++ b/literature/retrieval.py @@ -0,0 +1,398 @@ +""" +PDF retrieval module. +Downloads papers from ArXiv (priority) and via Unpaywall. +Implements robust header spoofing and graceful error handling. +""" +import logging +import os +import time +from pathlib import Path +from typing import Optional, List +import requests + +from .schemas import PaperMetadata, PaperSource +from .config import get_config + +logger = logging.getLogger(__name__) + + +class PDFRetriever: + """ + PDF retrieval with robust error handling. + + Priority: + 1. ArXiv (direct, free, reliable) + 2. Existing pdf_url from metadata + 3. Unpaywall via DOI + """ + + def __init__(self) -> None: + config = get_config() + self.storage_dir = Path(config.pdf_storage_dir) + self.storage_dir.mkdir(parents=True, exist_ok=True) + + # Robust headers to avoid 403 + self.headers = { + "User-Agent": config.user_agent, + "Accept": "application/pdf,*/*", + "Accept-Language": "en-US,en;q=0.9", + "Accept-Encoding": "gzip, deflate, br", + "Connection": "keep-alive", + } + + self.timeout = 60 # seconds + self.unpaywall_email = config.pubmed_email + + def retrieve_batch( + self, + papers: List[PaperMetadata], + skip_existing: bool = True + ) -> List[PaperMetadata]: + """ + Download PDFs for a batch of papers. + Updates paper.pdf_path for successful downloads. + Saves all papers and failed downloads to CSVs. 
+ + Args: + papers: List of paper metadata + skip_existing: Skip if PDF already exists + + Returns: + Updated list of papers with pdf_path set where successful + """ + successful_ids: set = set() + failed_papers: List[PaperMetadata] = [] + + for paper in papers: + try: + pdf_path = self.retrieve_single(paper, skip_existing=skip_existing) + if pdf_path: + paper.pdf_path = pdf_path + successful_ids.add(paper.id) + else: + failed_papers.append(paper) + except Exception as e: + logger.warning(f"PDF retrieval failed for {paper.id}: {e}") + failed_papers.append(paper) + + logger.info(f"PDF retrieval complete: {len(successful_ids)} successful, {len(failed_papers)} failed") + + # Save all papers with download status + self._save_all_papers(papers, successful_ids) + + # Save failed downloads for manual retrieval + if failed_papers: + self._save_failed_downloads(failed_papers) + + return papers + + def _save_failed_downloads(self, papers: List[PaperMetadata]) -> None: + """Save failed downloads to CSV for manual retrieval.""" + import csv + from datetime import datetime + + csv_path = self.storage_dir / "failed_downloads.csv" + file_exists = csv_path.exists() + + with open(csv_path, "a", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + + # Write header if new file + if not file_exists: + writer.writerow([ + "timestamp", "paper_id", "title", "source", "doi", "url", "expected_filename" + ]) + + timestamp = datetime.now().isoformat() + for paper in papers: + safe_id = paper.id.replace("/", "_").replace(":", "_") + expected_filename = f"{safe_id}.pdf" + writer.writerow([ + timestamp, + paper.id, + paper.title[:100], # Truncate long titles + paper.source.value, + paper.doi or "", + paper.url or "", + expected_filename + ]) + + logger.info(f"Saved {len(papers)} failed downloads to {csv_path}") + + def _save_all_papers( + self, + papers: List[PaperMetadata], + successful_ids: set + ) -> None: + """Save all discovered papers to CSV with download status.""" + import csv + from datetime import datetime + + csv_path = self.storage_dir / "all_papers.csv" + + with open(csv_path, "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow([ + "paper_id", "title", "source", "year", "doi", "url", + "pdf_downloaded", "pdf_path", "timestamp" + ]) + + timestamp = datetime.now().isoformat() + for paper in papers: + downloaded = paper.id in successful_ids or paper.pdf_path is not None + writer.writerow([ + paper.id, + paper.title[:150], + paper.source.value, + paper.year or "", + paper.doi or "", + paper.url or "", + "YES" if downloaded else "NO", + paper.pdf_path or "", + timestamp + ]) + + logger.info(f"Saved {len(papers)} papers to {csv_path}") + + def retrieve_single( + self, + paper: PaperMetadata, + skip_existing: bool = True + ) -> Optional[str]: + """ + Download PDF for a single paper. 
+ + Args: + paper: Paper metadata + skip_existing: Skip if file already exists + + Returns: + Path to downloaded PDF, or None if failed + """ + # Determine filename + safe_id = paper.id.replace("/", "_").replace(":", "_") + pdf_filename = f"{safe_id}.pdf" + pdf_path = self.storage_dir / pdf_filename + + # Check if already exists + if skip_existing and pdf_path.exists(): + logger.debug(f"PDF already exists: {pdf_path}") + return str(pdf_path) + + # Try download methods in priority order + pdf_url = self._get_pdf_url(paper) + + if pdf_url: + success = self._download_pdf(pdf_url, pdf_path) + if success: + logger.info(f"Downloaded PDF: {pdf_path}") + return str(pdf_path) + + logger.warning(f"Could not download PDF for {paper.id}") + return None + + def _get_pdf_url(self, paper: PaperMetadata) -> Optional[str]: + """ + Get PDF URL using priority order: + 1. ArXiv direct link + 2. PubMed Central (PMC) for PubMed papers + 3. Existing pdf_url from metadata + 4. Unpaywall via DOI + """ + # Priority 1: ArXiv (most reliable, free) + if paper.source == PaperSource.ARXIV: + arxiv_id = paper.id.replace("arxiv_", "") + return f"https://arxiv.org/pdf/{arxiv_id}.pdf" + + # Priority 2: PubMed - try PMC first + if paper.source == PaperSource.PUBMED: + pmc_url = self._get_pmc_pdf_url(paper) + if pmc_url: + return pmc_url + + # Priority 3: Use existing pdf_url if available + if paper.pdf_url: + return paper.pdf_url + + # Priority 4: Try Unpaywall via DOI (works for all sources) + if paper.doi: + unpaywall_url = self._get_unpaywall_url(paper.doi) + if unpaywall_url: + return unpaywall_url + + return None + + def _get_pmc_pdf_url(self, paper: PaperMetadata) -> Optional[str]: + """ + Try to get PDF from PubMed Central (PMC). + PMC provides free full-text PDFs for many PubMed articles. + """ + try: + pmid = paper.id.replace("pubmed_", "") + + # Try elink to get PMC ID + from Bio import Entrez + handle = Entrez.elink(dbfrom="pubmed", db="pmc", id=pmid) + record = Entrez.read(handle) + handle.close() + + # Check if PMC ID exists + link_sets = record[0].get("LinkSetDb", []) + for link_set in link_sets: + if link_set.get("DbTo") == "pmc": + links = link_set.get("Link", []) + if links: + pmc_id = links[0]["Id"] + # PMC PDF URL format + return f"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC{pmc_id}/pdf/" + + return None + + except Exception as e: + logger.debug(f"PMC lookup failed for {paper.id}: {e}") + return None + + def _get_unpaywall_url(self, doi: str) -> Optional[str]: + """ + Query Unpaywall API for open-access PDF URL. + + Args: + doi: Paper DOI + + Returns: + PDF URL if found, None otherwise + """ + try: + url = f"https://api.unpaywall.org/v2/{doi}" + params = {"email": self.unpaywall_email} + + response = requests.get( + url, + params=params, + headers=self.headers, + timeout=30 + ) + + if response.status_code != 200: + logger.debug(f"Unpaywall returned {response.status_code} for {doi}") + return None + + data = response.json() + + # Check for best open access location + best_oa = data.get("best_oa_location") + if best_oa and best_oa.get("url_for_pdf"): + return best_oa["url_for_pdf"] + + # Check all OA locations + oa_locations = data.get("oa_locations", []) + for loc in oa_locations: + if loc.get("url_for_pdf"): + return loc["url_for_pdf"] + + return None + + except Exception as e: + logger.debug(f"Unpaywall query failed for {doi}: {e}") + return None + + def _download_pdf(self, url: str, save_path: Path) -> bool: + """ + Download PDF from URL with robust error handling. 
+ + Args: + url: PDF URL + save_path: Local path to save file + + Returns: + True if successful, False otherwise + """ + try: + logger.debug(f"Downloading PDF from: {url}") + + response = requests.get( + url, + headers=self.headers, + timeout=self.timeout, + stream=True, + allow_redirects=True + ) + + # Check for success + if response.status_code != 200: + logger.warning(f"Download failed with status {response.status_code}: {url}") + return False + + # Verify it's a PDF (check content-type or magic bytes) + content_type = response.headers.get("content-type", "").lower() + if "pdf" not in content_type and "octet-stream" not in content_type: + # Check magic bytes as fallback + first_bytes = response.content[:8] + if not first_bytes.startswith(b"%PDF"): + logger.warning(f"Response is not a PDF: {content_type}") + return False + + # Save to file + with open(save_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + # Verify file was written + if save_path.exists() and save_path.stat().st_size > 0: + return True + + return False + + except requests.exceptions.Timeout: + logger.warning(f"Download timeout: {url}") + return False + except requests.exceptions.RequestException as e: + logger.warning(f"Download error: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error downloading {url}: {e}") + return False + + +def extract_text_from_pdf(pdf_path: str, max_pages: int = 100) -> Optional[str]: + """ + Extract text from PDF using pymupdf. + + Args: + pdf_path: Path to PDF file + max_pages: Maximum pages to extract (default 100) + + Returns: + Extracted text, or None if failed + """ + try: + import pymupdf # fitz + except ImportError: + try: + import fitz as pymupdf + except ImportError: + logger.error("pymupdf not installed. Run: pip install pymupdf") + return None + + try: + doc = pymupdf.open(pdf_path) + text_parts: List[str] = [] + + pages_to_extract = min(len(doc), max_pages) + + for page_num in range(pages_to_extract): + page = doc[page_num] + text = page.get_text() + if text: + text_parts.append(f"--- Page {page_num + 1} ---\n{text}") + + doc.close() + + full_text = "\n\n".join(text_parts) + logger.info(f"Extracted {len(full_text)} chars from {pages_to_extract} pages of {pdf_path}") + + return full_text if full_text.strip() else None + + except Exception as e: + logger.error(f"PDF text extraction failed for {pdf_path}: {e}") + return None diff --git a/literature/schemas.py b/literature/schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..3bcdfee2426200aed2d8b63375e698e42694128f --- /dev/null +++ b/literature/schemas.py @@ -0,0 +1,329 @@ +""" +Domain-specific data models for literature mining. +Supports contextualized extraction with source traceability.
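Retrieval and text extraction are usually chained before handing content to the extractor; a minimal sketch, assuming a configured environment and a hypothetical arXiv identifier:

```python
from literature.retrieval import PDFRetriever, extract_text_from_pdf
from literature.schemas import PaperMetadata, PaperSource

paper = PaperMetadata(
    id="arxiv_2401.00001",              # hypothetical arXiv ID
    title="Thermoelectric PEDOT:PSS films",
    source=PaperSource.ARXIV,
)

retriever = PDFRetriever()
pdf_path = retriever.retrieve_single(paper)   # resolves the arXiv URL, downloads, returns a local path or None

if pdf_path:
    paper.pdf_path = pdf_path
    paper.full_text = extract_text_from_pdf(pdf_path, max_pages=20)
```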
+""" +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict +from datetime import datetime +from enum import Enum + + +class DataQuality(str, Enum): + """Data quality tier.""" + GOLD = "gold" # Complete data with source quote + SILVER = "silver" # Partial data with source + BRONZE = "bronze" # Limited data or no source + ERROR = "error" # Extraction failed + + +class QueryMode(str, Enum): + """High-level search entrypoint modes.""" + MATERIAL = "material-first" + PROPERTY = "property-first" + TASK = "task-first" + + +class ReviewStatus(str, Enum): + """Human review status for staged evidence.""" + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + + +class PaperSource(str, Enum): + """Paper source identifier.""" + PUBMED = "pubmed" + ARXIV = "arxiv" + SEMANTIC_SCHOLAR = "s2" + MANUAL = "manual" + UNKNOWN = "unknown" + + +class PaperMetadata(BaseModel): + """Paper metadata from discovery.""" + id: str = Field(..., description="Unique ID, format: {source}_{original_id}") + title: str + authors: List[str] = Field(default_factory=list) + year: Optional[int] = None + doi: Optional[str] = None + abstract: Optional[str] = None + venue: Optional[str] = None + citation_count: Optional[int] = None + is_open_access: Optional[bool] = None + source: PaperSource = PaperSource.UNKNOWN + url: Optional[str] = None + landing_url: Optional[str] = None + pdf_url: Optional[str] = None + pdf_path: Optional[str] = None + full_text: Optional[str] = None + match_reasons: List[str] = Field(default_factory=list) + background_status: Optional[str] = None + retrieved_at: datetime = Field(default_factory=datetime.now) + + @field_validator('id') + @classmethod + def validate_id_format(cls, v: str) -> str: + """Ensure ID format is correct.""" + valid_prefixes = ['pubmed_', 'arxiv_', 's2_', 'manual_'] + if not any(v.startswith(p) for p in valid_prefixes): + raise ValueError(f"ID must start with one of {valid_prefixes}") + return v + + +class LiteratureQuerySpec(BaseModel): + """Normalized query payload used by the production literature UI.""" + mode: QueryMode + user_query: str + polymer_name: Optional[str] = None + canonical_smiles: Optional[str] = None + property_key: Optional[str] = None + project_id: Optional[str] = None + top_k_extract: int = Field(default=10, ge=1, le=50) + result_limit: int = Field(default=15, ge=1, le=100) + + +class PaperCardResult(BaseModel): + """User-facing paper card summary.""" + paper_id: str + title: str + year: Optional[int] = None + venue: Optional[str] = None + doi: Optional[str] = None + landing_url: Optional[str] = None + pdf_url: Optional[str] = None + is_open_access: bool = False + match_reasons: List[str] = Field(default_factory=list) + background_status: str = "discovered" + + +class LiteratureSupportSummary(BaseModel): + """Aggregated evidence coverage for a material/property view.""" + matched_paper_count: int = 0 + oa_paper_count: int = 0 + evidence_record_count: int = 0 + approved_record_count: int = 0 + has_experimental_evidence: bool = False + literature_support_score: int = Field(default=0, ge=0, le=100) + + +class LiteratureEvidenceRecord(BaseModel): + """Production staging record for extracted literature evidence.""" + id: Optional[str] = None + project_id: Optional[str] = None + paper_id: str + material_name: str + canonical_smiles: Optional[str] = None + property_key: str + raw_value: str + raw_unit: str + standardized_value: Optional[float] = None + standardized_unit: 
Optional[str] = None + conditions_json: Dict[str, Any] = Field(default_factory=dict) + method: Optional[str] = None + evidence_quote: str + evidence_location: Optional[str] = None + extractor_version: str + extraction_model: Optional[str] = None + extraction_confidence: float = Field(default=0.5, ge=0.0, le=1.0) + quality_tier: DataQuality = DataQuality.BRONZE + review_status: ReviewStatus = ReviewStatus.PENDING + reviewer_note: Optional[str] = None + edited_payload_json: Optional[Dict[str, Any]] = None + created_at: Optional[str] = None + updated_at: Optional[str] = None + + @field_validator("evidence_quote") + @classmethod + def validate_evidence_quote(cls, v: str) -> str: + text = str(v or "").strip() + if len(text) < 10: + raise ValueError("evidence_quote must be at least 10 characters") + return text + + +# ============== Experimental Conditions ============== + +class ExperimentalConditions(BaseModel): + """ + Experimental conditions with full context. + + ⚠️ extra="allow" keeps LLM-returned fields like humidity, substrate, etc. + """ + model_config = ConfigDict(extra="allow") + + # Preparation conditions + solvent: Optional[str] = None + concentration_mg_ml: Optional[float] = None + spin_speed_rpm: Optional[int] = None + spin_time_s: Optional[int] = None + annealing_temp_c: Optional[float] = None + annealing_time_min: Optional[float] = None + annealing_atmosphere: Optional[str] = None + film_thickness_nm: Optional[float] = None + + # Measurement conditions + measurement_temp_k: Optional[float] = Field(None, description="Measurement temperature (K)") + measurement_method: Optional[str] = None + measurement_direction: Optional[str] = None # in-plane, cross-plane + + def to_dict(self) -> dict: + """Convert to dict, excluding None values.""" + return {k: v for k, v in self.model_dump().items() if v is not None} + + +# ============== Contextualized Value ============== + +class ContextualizedValue(BaseModel): + """ + Measurement value with full experimental context and source traceability. + + Design principles: + - Same paper may report multiple values under different conditions + - Each value MUST have its associated experimental conditions + - MANDATORY: source_quote for traceability + """ + model_config = ConfigDict(extra="allow") + + # Material + polymer_name: str = Field(..., description="Polymer name e.g. PEDOT:PSS") + dopant: Optional[str] = None + dopant_ratio: Optional[str] = None + + # Property measured + property_name: str = Field(..., description="Property name e.g. electrical_conductivity") + + # Raw value + raw_value: str = Field(..., description="Raw value string from paper") + raw_unit: str = Field(..., description="Original unit from paper") + + # Standardized value (filled by Standardizer) + standardized_value: Optional[float] = None + standardized_unit: Optional[str] = None + standardization_error: Optional[str] = None + + # Experimental conditions + conditions: ExperimentalConditions = Field(default_factory=ExperimentalConditions) + + # Source traceability (MANDATORY!) 
+ source_quote: str = Field(..., description="Exact quote from paper containing this value") + source_location: Optional[str] = Field(None, description="Table 1, Figure 3a, etc.") + + # Quality + extraction_confidence: float = Field(default=0.5, ge=0.0, le=1.0) + quality_tier: DataQuality = DataQuality.BRONZE + + @field_validator('source_quote') + @classmethod + def quote_not_empty(cls, v: str) -> str: + if not v or len(v.strip()) < 10: + raise ValueError("source_quote must be >10 chars") + return v.strip() + + def to_db_dict(self) -> dict: + """Convert to database storage format.""" + return { + "polymer_name": self.polymer_name, + "dopant": self.dopant, + "dopant_ratio": self.dopant_ratio, + "property_name": self.property_name, + "raw_value": self.raw_value, + "raw_unit": self.raw_unit, + "standardized_value": self.standardized_value, + "standardized_unit": self.standardized_unit, + "conditions": self.conditions.to_dict(), + "source_quote": self.source_quote, + "source_location": self.source_location, + "extraction_confidence": self.extraction_confidence, + "quality_tier": self.quality_tier.value, + } + + +# ============== Legacy PolymerDataPoint (for compatibility) ============== + +class PolymerDataPoint(BaseModel): + """Single data point extracted from literature (legacy format).""" + # Material Information + polymer_name: str = Field(..., description="Polymer name, e.g. P3HT, PEDOT:PSS") + polymer_class: Optional[str] = Field(None, description="Polymer class") + dopant: Optional[str] = None + dopant_ratio: Optional[str] = None + + # Processing Conditions + solvent: Optional[str] = None + concentration_mg_ml: Optional[float] = None + spin_speed_rpm: Optional[int] = None + spin_time_s: Optional[int] = None + annealing_temp_c: Optional[float] = None + annealing_time_min: Optional[float] = None + annealing_atmosphere: Optional[str] = None + film_thickness_nm: Optional[float] = None + + # Electrical Properties + electrical_conductivity_s_cm: Optional[float] = None + seebeck_coefficient_uv_k: Optional[float] = None + power_factor_uw_m_k2: Optional[float] = None + + # Thermal Properties + thermal_conductivity_w_mk: Optional[float] = None + zt_figure_of_merit: Optional[float] = None + + # Structural + xrd_crystallinity_percent: Optional[float] = None + xrd_pi_stacking_angstrom: Optional[float] = None + xrd_lamellar_spacing_angstrom: Optional[float] = None + + # Metadata + source_paper_id: str + source_table_or_figure: Optional[str] = None + extraction_confidence: float = Field(default=0.5, ge=0.0, le=1.0) + quality_tier: DataQuality = DataQuality.BRONZE + raw_text_snippet: Optional[str] = None + + @field_validator('electrical_conductivity_s_cm', 'thermal_conductivity_w_mk', mode='before') + @classmethod + def validate_positive(cls, v: Any) -> Optional[float]: + if v is not None and isinstance(v, (int, float)) and v < 0: + return None + return v + + +# ============== Extraction Result ============== + +class ExtractionResult(BaseModel): + """ + Extraction result for a single paper. + + Supports both old format (paper=PaperMetadata) and new format (paper_id, paper_title). 
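The mandatory source quote is enforced by the model itself, so malformed records fail fast at construction time; a short sketch building a `ContextualizedValue` (values are illustrative):

```python
from literature.schemas import ContextualizedValue, ExperimentalConditions

value = ContextualizedValue(
    polymer_name="PEDOT:PSS",
    dopant="H2SO4",
    property_name="electrical_conductivity",
    raw_value="4380",
    raw_unit="S/cm",
    conditions=ExperimentalConditions(measurement_method="4-point probe", measurement_temp_k=300),
    source_quote="The electrical conductivity reached 4380 S/cm after H2SO4 treatment.",
    source_location="Table 2",
    extraction_confidence=0.9,
)

print(value.to_db_dict()["conditions"])   # only the non-None condition fields

# A quote shorter than 10 characters raises a pydantic ValidationError:
# ContextualizedValue(..., source_quote="n/a")
```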
+ """ + model_config = ConfigDict(extra="allow") + + # New format fields (preferred) + paper_id: Optional[str] = None + paper_title: Optional[str] = None + + # Old format field (for backward compatibility) + paper: Optional[PaperMetadata] = None + + # Common fields + data_points: List = Field(default_factory=list) # Can be ContextualizedValue or PolymerDataPoint + extraction_model: str = "unknown" + extraction_timestamp: Any = Field(default_factory=lambda: datetime.now().isoformat()) + success: bool = True + error_message: Optional[str] = None + + # Legacy fields + llm_model_used: Optional[str] = None + extraction_notes: Optional[str] = None + + @model_validator(mode='after') + def extract_paper_fields(self): + """Extract paper_id and paper_title from paper if not provided.""" + if self.paper is not None: + if self.paper_id is None: + self.paper_id = self.paper.id + if self.paper_title is None: + self.paper_title = self.paper.title + # Copy llm_model_used to extraction_model if present + if self.llm_model_used and self.extraction_model == "unknown": + self.extraction_model = self.llm_model_used + return self diff --git a/literature/standardizer.py b/literature/standardizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0407786556098ce6244755ca40581ca3bfc1cca2 --- /dev/null +++ b/literature/standardizer.py @@ -0,0 +1,211 @@ +""" +Unit standardization for production literature evidence. + +The standard units are aligned with the platform property catalog so extracted +evidence can be compared and filtered consistently before human review. +""" +import logging +import re +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional + +from .property_registry import PROPERTY_CATALOG + +logger = logging.getLogger(__name__) + + +@dataclass +class StandardizationResult: + """Standardization result.""" + success: bool + value: Optional[float] = None + unit: Optional[str] = None + error: Optional[str] = None + + +def normalize_minus_signs(s: str) -> str: + """Normalize all Unicode minus signs to ASCII hyphen-minus.""" + minus_chars = [ + "−", "–", "—", "‐", "‑", "‒", "⁻", "₋", "➖", + ] + for char in minus_chars: + s = s.replace(char, "-") + return s + + +def _identity(value: float) -> float: + return value + + +def _mul(factor: float) -> Callable[[float], float]: + return lambda value: value * factor + + +def _add(delta: float) -> Callable[[float], float]: + return lambda value: value + delta + + +class UnitStandardizer: + """Convert raw values from papers to platform-standard units.""" + + STANDARD_UNITS = {key: meta["unit"] for key, meta in PROPERTY_CATALOG.items()} + + UNIT_ALIASES = { + # Temperature + "k": "K", + "kelvin": "K", + "c": "C", + "°c": "C", + "deg c": "C", + "celsius": "C", + # Thermal + "w/mk": "W/(m*K)", + "w/(m·k)": "W/(m*K)", + "w m-1 k-1": "W/(m*K)", + "w·m⁻¹·k⁻¹": "W/(m*K)", + "mw/(m*k)": "mW/(m*K)", + "mw/(m·k)": "mW/(m*K)", + "j/kgk": "J/(kg*K)", + "j/(kg·k)": "J/(kg*K)", + "j/(kg*k)": "J/(kg*K)", + "j/gk": "J/(g*K)", + "j/(g*k)": "J/(g*K)", + # Mechanical + "gpa": "GPa", + "mpa": "MPa", + # Transport / physical + "pa s": "Pa*s", + "pa·s": "Pa*s", + "pas": "Pa*s", + "mpa*s": "mPa*s", + "cm2/s": "cm^2/s", + "cm^2/s": "cm^2/s", + "mm2/s": "mm^2/s", + "mm^2/s": "mm^2/s", + "g/cm3": "g/cm^3", + "g/cm^3": "g/cm^3", + "kg/m3": "kg/m^3", + "kg/m^3": "kg/m^3", + "ang": "Angstrom", + "angstrom": "Angstrom", + "å": "Angstrom", + "nm": "nm", + # Electronics + "ev": "eV", + "a.u.": "a.u.", + "au": "a.u.", + "debye": "Debye", + # Gas / 
transport + "barrer": "Barrer", + # Extended literature properties + "s/cm": "S/cm", + "s m-1": "S/m", + "s/m": "S/m", + "uv/k": "uV/K", + "μv/k": "uV/K", + "µv/k": "uV/K", + "mv/k": "mV/K", + "uw/(m*k^2)": "uW/(m*K^2)", + "uw/(m*k**2)": "uW/(m*K^2)", + "uw/(m·k²)": "uW/(m*K^2)", + "mw/(m*k^2)": "mW/(m*K^2)", + "%": "%", + "dimensionless": "", + "-": "", + "": "", + } + + CONVERSIONS: Dict[str, Dict[tuple[str, str], Callable[[float], float]]] = { + "tm": {("C", "K"): _add(273.15)}, + "tg": {("C", "K"): _add(273.15)}, + "cp": {("J/(g*K)", "J/(kg*K)"): _mul(1000.0)}, + "tc": {("mW/(m*K)", "W/(m*K)"): _mul(0.001)}, + "young": {("MPa", "GPa"): _mul(0.001)}, + "shear": {("MPa", "GPa"): _mul(0.001)}, + "bulk": {("MPa", "GPa"): _mul(0.001)}, + "visc": {("mPa*s", "Pa*s"): _mul(0.001)}, + "dif": {("mm^2/s", "cm^2/s"): _mul(0.01)}, + "rho": {("kg/m^3", "g/cm^3"): _mul(0.001)}, + "rg": {("nm", "Angstrom"): _mul(10.0)}, + "electrical_conductivity": {("S/m", "S/cm"): _mul(0.01)}, + "seebeck_coefficient": {("mV/K", "uV/K"): _mul(1000.0)}, + "power_factor": {("mW/(m*K^2)", "uW/(m*K^2)"): _mul(1000.0)}, + } + + def standardize( + self, + property_name: str, + raw_value: str, + raw_unit: str, + ) -> StandardizationResult: + try: + numeric = self._parse_numeric(raw_value) + except ValueError as exc: + return StandardizationResult(success=False, error=f"Parse error: {exc}") + + standard_unit = self.STANDARD_UNITS.get(property_name) + if standard_unit is None: + return StandardizationResult(success=False, error=f"Unknown property: {property_name}") + + normalized = self._normalize_unit(raw_unit) + if standard_unit in {"dimensionless", ""}: + return StandardizationResult(success=True, value=numeric, unit="") + + if normalized == standard_unit: + return StandardizationResult(success=True, value=numeric, unit=standard_unit) + + transform = self.CONVERSIONS.get(property_name, {}).get((normalized, standard_unit)) + if transform is not None: + return StandardizationResult(success=True, value=transform(numeric), unit=standard_unit) + + if normalized == "": + return StandardizationResult(success=False, error=f"Missing unit for {property_name}") + + return StandardizationResult( + success=False, + error=f"Cannot convert {normalized} to {standard_unit} for {property_name}", + ) + + def _parse_numeric(self, value_str: str) -> float: + s = normalize_minus_signs(str(value_str or "").strip()) + s = re.sub(r"\s*[×x]\s*10\^?\s*(-?\d+)", r"e\1", s) + superscripts = { + "⁰": "0", "¹": "1", "²": "2", "³": "3", "⁴": "4", + "⁵": "5", "⁶": "6", "⁷": "7", "⁸": "8", "⁹": "9", "⁻": "-", + } + for sup, norm in superscripts.items(): + s = s.replace(sup, norm) + s = s.replace(" ", "") + + range_match = re.match(r"^(\d+(?:\.\d+)?)\s*-\s*(\d+(?:\.\d+)?)$", s) + if range_match: + low = float(range_match.group(1)) + high = float(range_match.group(2)) + return (low + high) / 2 + + pm_match = re.match(r"^([\d.eE+-]+)\s*[±]\s*[\d.eE+-]+$", s) + if pm_match: + return float(pm_match.group(1)) + + return float(s) + + def _normalize_unit(self, unit: str) -> str: + normalized = normalize_minus_signs(str(unit or "").strip()) + normalized = normalized.replace("²", "^2").replace("³", "^3") + normalized = normalized.replace("·", "*").replace(" ", " ") + key = re.sub(r"\s+", " ", normalized.lower()).strip() + return self.UNIT_ALIASES.get(key, normalized) + + def standardize_data_points(self, data_points: List) -> List: + for dp in data_points: + result = self.standardize( + property_name=dp.property_name, + raw_value=dp.raw_value, + 
raw_unit=dp.raw_unit, + ) + if result.success: + dp.standardized_value = result.value + dp.standardized_unit = result.unit + else: + dp.standardization_error = result.error + return data_points diff --git a/scripts/__pycache__/run_literature_mining.cpython-313.pyc b/scripts/__pycache__/run_literature_mining.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cc842d8919cb18dfc6c4952739cdb0db192bdee Binary files /dev/null and b/scripts/__pycache__/run_literature_mining.cpython-313.pyc differ diff --git a/scripts/evaluate_polyie.py b/scripts/evaluate_polyie.py new file mode 100644 index 0000000000000000000000000000000000000000..b43ff0ac38ffdc2958fb6851cbc84d64b3c4703a --- /dev/null +++ b/scripts/evaluate_polyie.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from literature.evaluation import evaluate_predictions, load_json_records + + +def main() -> None: + parser = argparse.ArgumentParser(description="Evaluate extraction output against a POLYIE-style gold file.") + parser.add_argument("--gold", required=True, help="Gold file (.json or .jsonl)") + parser.add_argument("--pred", required=True, help="Prediction file (.json or .jsonl)") + parser.add_argument("--out", default=None, help="Optional JSON output path") + args = parser.parse_args() + + gold_records = load_json_records(args.gold) + predicted_records = load_json_records(args.pred) + metrics = evaluate_predictions(gold_records, predicted_records) + text = json.dumps(metrics, indent=2, ensure_ascii=False) + print(text) + + if args.out: + Path(args.out).write_text(text + "\n", encoding="utf-8") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_literature_mining.py b/scripts/run_literature_mining.py new file mode 100644 index 0000000000000000000000000000000000000000..79ff1a87d4581f8a0befb126e3c477729b94317b --- /dev/null +++ b/scripts/run_literature_mining.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +""" +Project-based literature mining CLI. 
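Standardization brings raw literature values onto the catalog units before human review; a couple of illustrative calls against the conversions defined above:

```python
from literature.standardizer import UnitStandardizer, normalize_minus_signs

std = UnitStandardizer()

# Glass transition reported in Celsius -> catalog unit is Kelvin.
result = std.standardize("tg", "105", "°C")
print(result.success, result.value, result.unit)   # True 378.15 K

# Conductivity reported in S/m -> catalog unit is S/cm.
result = std.standardize("electrical_conductivity", "4.2e4", "S/m")
print(result.value, result.unit)                   # 420.0 S/cm

# Unicode minus signs are normalized before numeric parsing.
print(normalize_minus_signs("−3.1"))               # "-3.1"
```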
+ +Examples: + python scripts/run_literature_mining.py --query "PEDOT:PSS thermoelectric" --limit 5 + python scripts/run_literature_mining.py --project-id proj_xxx --query "P3HT conductivity" --save-mode files +""" +from __future__ import annotations + +import argparse +import csv +import json +from pathlib import Path +from typing import Any, Dict, List + +from dotenv import load_dotenv + +from src.literature_service import ( + DataPointRepo, + LiteraturePipeline, + ProjectRepo, + QueryIntentService, + QuerySessionRepo, + get_database, +) + +load_dotenv() + + +def resolve_project_id(project_id: str | None, projects: ProjectRepo) -> str: + if project_id: + project = projects.get_project(project_id) + if not project: + raise ValueError(f"Project not found: {project_id}") + return project_id + + existing = projects.list_projects() + if existing: + return existing[0]["id"] + + created = projects.create_project( + name="Default Literature Project", + description="Auto-created by run_literature_mining.py", + ) + return created["id"] + + +def export_points_to_files(project_id: str, points: List[Dict[str, Any]], out_dir: Path) -> None: + out_dir.mkdir(parents=True, exist_ok=True) + + jsonl_path = out_dir / "validated_points.jsonl" + with jsonl_path.open("w", encoding="utf-8") as f: + for row in points: + f.write(json.dumps(row, ensure_ascii=False) + "\n") + + csv_path = out_dir / "validated_points.csv" + if points: + with csv_path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=list(points[0].keys())) + writer.writeheader() + writer.writerows(points) + else: + csv_path.write_text("point_id,project_id\n", encoding="utf-8") + + print(f"Exported {len(points)} rows to:") + print(f" - {jsonl_path}") + print(f" - {csv_path}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Project-based Literature Mining CLI") + parser.add_argument("--project-id", default=None, help="Target project ID") + parser.add_argument("--query", default="PEDOT:PSS thermoelectric conductivity", help="Search query") + parser.add_argument("--limit", type=int, default=5, help="Max papers per source") + parser.add_argument("--strategy", choices=["simple", "paperqa"], default="simple", help="Extraction strategy") + parser.add_argument("--model-provider", default="openai_compatible", help="Model provider name") + parser.add_argument("--model-name", default="gpt-oss:latest", help="Model name") + parser.add_argument("--save-mode", choices=["sqlite", "files"], default="sqlite", help="Result sink mode") + parser.add_argument("--no-save", action="store_true", help="Do not persist result to sqlite") + parser.add_argument("--manual-upload-dir", default="data/literature/manual_uploads", help="Reserved for batch manual upload") + args = parser.parse_args() + + db = get_database("data/app.db") + project_repo = ProjectRepo(db) + point_repo = DataPointRepo(db) + query_repo = QuerySessionRepo(db) + query_intent = QueryIntentService(query_repo) + pipeline = LiteraturePipeline(db_path="data/app.db") + + target_project_id = resolve_project_id(args.project_id, project_repo) + project = project_repo.get_project(target_project_id) + print("=" * 64) + print("Project-Based Literature Mining") + print(f"Project: {project['name']} ({target_project_id})") + print(f"Query: {args.query}") + print(f"Limit per source: {args.limit}") + print(f"Strategy: {args.strategy}") + print("=" * 64) + + query_session = query_intent.analyze_and_store(target_project_id, args.query) + suggestions = 
json.loads(query_session.get("suggestions_json") or "[]") + if suggestions: + print("Query suggestions:") + for s in suggestions: + print(f" - {s}") + if query_session.get("clarification_required"): + print("Note: query marked as pending_clarification. Continuing by CLI override.") + + if args.no_save: + discovered = pipeline.run_discovery(target_project_id, args.query, args.limit) + retrieved = pipeline.run_retrieval(target_project_id, discovered) + stats = pipeline.run_extraction( + target_project_id, + run_id=None, + paper_rows=retrieved, + strategy=args.strategy, + model_name=args.model_name, + use_full_text=True, + ) + print(f"Extraction complete without DB run record: {stats}") + else: + result = pipeline.run_full_pipeline( + project_id=target_project_id, + query=args.query, + limit=args.limit, + strategy=args.strategy, + model_provider=args.model_provider, + model_name=args.model_name, + use_full_text=True, + ) + print(f"Pipeline status: {result.get('status')}") + if result.get("status") != "completed": + print(f"Error: {result.get('error')}") + else: + print(json.dumps(result.get("stats", {}), indent=2)) + + points = point_repo.list_points(target_project_id) + if args.save_mode == "files": + run_dir = Path("data/literature/runs") + export_points_to_files(target_project_id, points, run_dir) + + print("=" * 64) + print("Done.") + print("=" * 64) + + +if __name__ == "__main__": + main() diff --git a/scripts/train_prior_slurm.sh b/scripts/train_prior_slurm.sh new file mode 100644 index 0000000000000000000000000000000000000000..28d390d29ca5151637429c6ba8dac5a7faf8baea --- /dev/null +++ b/scripts/train_prior_slurm.sh @@ -0,0 +1,38 @@ +#!/bin/bash +#SBATCH --job-name=polymer_prior +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64G +#SBATCH --time=24:00:00 +#SBATCH --output=logs/train_prior_%j.out +#SBATCH --error=logs/train_prior_%j.err + +set -euo pipefail + +# Adjust these for your CRC environment +REPO_DIR="/Users/xuguoyue/Documents/GitHub/POLYMER-PROPERTY" +VENV_DIR="$REPO_DIR/.venv" + +cd "$REPO_DIR" + +# Load modules if your CRC requires it (example) +# module load python/3.10 + +source "$VENV_DIR/bin/activate" + +mkdir -p logs + +export OMP_NUM_THREADS=8 +export MKL_NUM_THREADS=8 + +torchrun --nproc_per_node=4 RNN/train_prior.py \ + --smiles-csv data/PI1M.csv \ + --vocab RNN/pretrained_model/voc \ + --output RNN/pretrained_model/Prior.ckpt \ + --epochs 10 \ + --batch-size 256 \ + --lr 1e-3 \ + --max-length 140 \ + --num-workers 4 \ + --log-every 200 diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..415cd4aed244ebdcdebd9aacd740f8220345bb29 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/__pycache__/conv.cpython-310.pyc b/src/__pycache__/conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39b97b1c6860e1f34c17846e81cb6f8f0f6496df Binary files /dev/null and b/src/__pycache__/conv.cpython-310.pyc differ diff --git a/src/__pycache__/conv.cpython-313.pyc b/src/__pycache__/conv.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb8afdc5bcac63616898ea7a3140a18dc56c2092 Binary files /dev/null and b/src/__pycache__/conv.cpython-313.pyc differ diff --git a/src/__pycache__/data_builder.cpython-310.pyc b/src/__pycache__/data_builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30a0abad183dd1b38b7cac2e8c3784fc1e406d1a Binary files /dev/null and 
b/src/__pycache__/data_builder.cpython-310.pyc differ diff --git a/src/__pycache__/data_builder.cpython-313.pyc b/src/__pycache__/data_builder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1531e70801c01aa4b55f73a9e6bdd67951b3d03d Binary files /dev/null and b/src/__pycache__/data_builder.cpython-313.pyc differ diff --git a/src/__pycache__/discover_llm.cpython-310.pyc b/src/__pycache__/discover_llm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d8c859e98b24e5647516dacdc528864ef34178b Binary files /dev/null and b/src/__pycache__/discover_llm.cpython-310.pyc differ diff --git a/src/__pycache__/discover_llm.cpython-313.pyc b/src/__pycache__/discover_llm.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0de34c6c5d8aadbe2d7e7b2abe1a513e03224167 Binary files /dev/null and b/src/__pycache__/discover_llm.cpython-313.pyc differ diff --git a/src/__pycache__/discovery.cpython-310.pyc b/src/__pycache__/discovery.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5eef6ac4313f1ec99c3457a05ef7d8adb745ddd Binary files /dev/null and b/src/__pycache__/discovery.cpython-310.pyc differ diff --git a/src/__pycache__/discovery.cpython-313.pyc b/src/__pycache__/discovery.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e8b67efec459ffeddc154f7c2e4e66fcf2feb3c Binary files /dev/null and b/src/__pycache__/discovery.cpython-313.pyc differ diff --git a/src/__pycache__/literature_ui.cpython-310.pyc b/src/__pycache__/literature_ui.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aa36a2c9afc954218add69e855c5bba642f38d8 Binary files /dev/null and b/src/__pycache__/literature_ui.cpython-310.pyc differ diff --git a/src/__pycache__/lookup.cpython-310.pyc b/src/__pycache__/lookup.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d940dfceb6612409516640520e5e69dd2b437f7a Binary files /dev/null and b/src/__pycache__/lookup.cpython-310.pyc differ diff --git a/src/__pycache__/lookup.cpython-313.pyc b/src/__pycache__/lookup.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c45a01818852bcb4b1d9f69f0522f559b4785a76 Binary files /dev/null and b/src/__pycache__/lookup.cpython-313.pyc differ diff --git a/src/__pycache__/model.cpython-310.pyc b/src/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecb4a17c5caeb03993ec58259a65cd5a1b69674c Binary files /dev/null and b/src/__pycache__/model.cpython-310.pyc differ diff --git a/src/__pycache__/model.cpython-313.pyc b/src/__pycache__/model.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25daab955a6b260241f850f2ee77b80d7fece075 Binary files /dev/null and b/src/__pycache__/model.cpython-313.pyc differ diff --git a/src/__pycache__/predictor.cpython-310.pyc b/src/__pycache__/predictor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1b9fbf6257820e598ba577cdf033308932543ce Binary files /dev/null and b/src/__pycache__/predictor.cpython-310.pyc differ diff --git a/src/__pycache__/predictor.cpython-313.pyc b/src/__pycache__/predictor.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfa262a4d7a48512d7756c7fa56a1be793d35947 Binary files /dev/null and b/src/__pycache__/predictor.cpython-313.pyc differ diff --git 
a/src/__pycache__/predictor_multitask.cpython-310.pyc b/src/__pycache__/predictor_multitask.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aeb055f1f6897ca626837af4c006fe297dd2f23 Binary files /dev/null and b/src/__pycache__/predictor_multitask.cpython-310.pyc differ diff --git a/src/__pycache__/predictor_multitask.cpython-313.pyc b/src/__pycache__/predictor_multitask.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4914c85ad49d905504cbf20ab23d961a95d8b742 Binary files /dev/null and b/src/__pycache__/predictor_multitask.cpython-313.pyc differ diff --git a/src/__pycache__/predictor_router.cpython-310.pyc b/src/__pycache__/predictor_router.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83ac163aa7ce319ef05c77da5651a1d55ec63110 Binary files /dev/null and b/src/__pycache__/predictor_router.cpython-310.pyc differ diff --git a/src/__pycache__/predictor_router.cpython-313.pyc b/src/__pycache__/predictor_router.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74f3536cf285d466d8ecf0c451e36a5209339c0d Binary files /dev/null and b/src/__pycache__/predictor_router.cpython-313.pyc differ diff --git a/src/__pycache__/sascorer.cpython-310.pyc b/src/__pycache__/sascorer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67b436aa8c42bb174155a2639ada271170131662 Binary files /dev/null and b/src/__pycache__/sascorer.cpython-310.pyc differ diff --git a/src/__pycache__/sascorer.cpython-313.pyc b/src/__pycache__/sascorer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f493c07fb42cf909bf5abb543cb5c8262a4ce063 Binary files /dev/null and b/src/__pycache__/sascorer.cpython-313.pyc differ diff --git a/src/__pycache__/streamlit_app.cpython-313.pyc b/src/__pycache__/streamlit_app.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2d27a9bb97354cb7dd674f9ffd39ed4eba40b01 Binary files /dev/null and b/src/__pycache__/streamlit_app.cpython-313.pyc differ diff --git a/src/__pycache__/ui_style.cpython-310.pyc b/src/__pycache__/ui_style.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d87cd244a80b75afd0b6fbb21961a6760b0e31d1 Binary files /dev/null and b/src/__pycache__/ui_style.cpython-310.pyc differ diff --git a/src/__pycache__/ui_style.cpython-313.pyc b/src/__pycache__/ui_style.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7daa22388d1451931e7386c97c4d419c2633be2d Binary files /dev/null and b/src/__pycache__/ui_style.cpython-313.pyc differ diff --git a/src/__pycache__/utils.cpython-310.pyc b/src/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9eedf83ede7554c70d62ecd244a2cd59098a7291 Binary files /dev/null and b/src/__pycache__/utils.cpython-310.pyc differ diff --git a/src/__pycache__/utils.cpython-313.pyc b/src/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..591c7abe24c92cf7ece9c275eb5235aa5bcf5531 Binary files /dev/null and b/src/__pycache__/utils.cpython-313.pyc differ diff --git a/src/conv.py b/src/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bd45c23d6754059a21e28e585acff033014e1c --- /dev/null +++ b/src/conv.py @@ -0,0 +1,258 @@ +# conv.py +# Clean, dependency-light graph encoder blocks for molecular GNNs. 
+# - Single source of truth for convolution choices: "gine", "gin", "gcn" +# - Edge attributes are supported for "gine" (recommended for chemistry) +# - No duplication with PyG built-ins; everything wraps torch_geometric.nn +# - Consistent encoder API: GNNEncoder(...).forward(x, edge_index, edge_attr, batch) -> graph embedding [B, emb_dim] + +from __future__ import annotations +from typing import Literal, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import ( + GINEConv, + GINConv, + GCNConv, + global_mean_pool, + global_add_pool, + global_max_pool, +) + + +def get_activation(name: str) -> nn.Module: + name = name.lower() + if name == "relu": + return nn.ReLU() + if name == "gelu": + return nn.GELU() + if name == "silu": + return nn.SiLU() + if name in ("leaky_relu", "lrelu"): + return nn.LeakyReLU(0.1) + raise ValueError(f"Unknown activation: {name}") + + +class MLP(nn.Module): + """Small MLP used inside GNN layers and projections.""" + def __init__( + self, + in_dim: int, + hidden_dim: int, + out_dim: int, + num_layers: int = 2, + act: str = "relu", + dropout: float = 0.0, + bias: bool = True, + ): + super().__init__() + assert num_layers >= 1 + layers: list[nn.Module] = [] + dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim] + for i in range(len(dims) - 1): + layers.append(nn.Linear(dims[i], dims[i + 1], bias=bias)) + if i < len(dims) - 2: + layers.append(get_activation(act)) + if dropout > 0: + layers.append(nn.Dropout(dropout)) + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class NodeProjector(nn.Module): + """Projects raw node features to model embedding size.""" + def __init__(self, in_dim_node: int, emb_dim: int, act: str = "relu"): + super().__init__() + if in_dim_node == emb_dim: + self.proj = nn.Identity() + else: + self.proj = nn.Sequential( + nn.Linear(in_dim_node, emb_dim), + get_activation(act), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.proj(x) + + +class EdgeProjector(nn.Module): + """Projects raw edge attributes to model embedding size for GINE.""" + def __init__(self, in_dim_edge: int, emb_dim: int, act: str = "relu"): + super().__init__() + if in_dim_edge <= 0: + raise ValueError("in_dim_edge must be > 0 when using edge attributes") + self.proj = nn.Sequential( + nn.Linear(in_dim_edge, emb_dim), + get_activation(act), + ) + + def forward(self, e: torch.Tensor) -> torch.Tensor: + return self.proj(e) + + +class GNNEncoder(nn.Module): + """ + Backbone GNN with selectable conv type. 
+ + gnn_type: + - "gine": chemistry-ready, uses edge_attr (recommended) + - "gin" : ignores edge_attr, strong node MPNN + - "gcn" : ignores edge_attr, fast spectral conv + norm: "batch" | "layer" | "none" + readout: "mean" | "sum" | "max" + """ + + def __init__( + self, + in_dim_node: int, + emb_dim: int, + num_layers: int = 5, + gnn_type: Literal["gine", "gin", "gcn"] = "gine", + in_dim_edge: int = 0, + act: str = "relu", + dropout: float = 0.0, + residual: bool = True, + norm: Literal["batch", "layer", "none"] = "batch", + readout: Literal["mean", "sum", "max"] = "mean", + ): + super().__init__() + assert num_layers >= 1 + + self.gnn_type = gnn_type.lower() + self.emb_dim = emb_dim + self.num_layers = num_layers + self.residual = residual + self.dropout_p = float(dropout) + self.readout = readout.lower() + + self.node_proj = NodeProjector(in_dim_node, emb_dim, act=act) + self.edge_proj: Optional[EdgeProjector] = None + + if self.gnn_type == "gine": + if in_dim_edge <= 0: + raise ValueError( + "gine selected but in_dim_edge <= 0. Provide edge attributes or switch gnn_type." + ) + self.edge_proj = EdgeProjector(in_dim_edge, emb_dim, act=act) + + # Build conv stack + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + + for _ in range(num_layers): + if self.gnn_type == "gine": + # edge_attr must be projected to emb_dim + nn_mlp = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0) + conv = GINEConv(nn_mlp) + elif self.gnn_type == "gin": + nn_mlp = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0) + conv = GINConv(nn_mlp) + elif self.gnn_type == "gcn": + conv = GCNConv(emb_dim, emb_dim, add_self_loops=True, normalize=True) + else: + raise ValueError(f"Unknown gnn_type: {gnn_type}") + self.convs.append(conv) + + if norm == "batch": + self.norms.append(nn.BatchNorm1d(emb_dim)) + elif norm == "layer": + self.norms.append(nn.LayerNorm(emb_dim)) + elif norm == "none": + self.norms.append(nn.Identity()) + else: + raise ValueError(f"Unknown norm: {norm}") + + self.act = get_activation(act) + + def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: + if self.readout == "mean": + return global_mean_pool(x, batch) + if self.readout == "sum": + return global_add_pool(x, batch) + if self.readout == "max": + return global_max_pool(x, batch) + raise ValueError(f"Unknown readout: {self.readout}") + + def forward( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + edge_attr: Optional[torch.Tensor], + batch: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + Returns a graph-level embedding of shape [B, emb_dim]. + If batch is None, assumes a single graph and creates a zero batch vector. 
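+
+ Example (illustrative sketch, not from the original file; emb_dim=128 is arbitrary, while the 35 node / 12 edge feature sizes match data_builder.py's featurizer):
+ enc = GNNEncoder(in_dim_node=35, emb_dim=128, in_dim_edge=12, gnn_type="gine")
+ g = enc(data.x, data.edge_index, data.edge_attr, data.batch)  # data: a PyG Batch -> g: [B, 128]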
+ """ + if batch is None: + batch = x.new_zeros(x.size(0), dtype=torch.long) + + # Project features (ensure float dtype) + x = x.float() + x = self.node_proj(x) + + e = None + if self.gnn_type == "gine": + if edge_attr is None: + raise ValueError("GINE requires edge_attr, but got None.") + e = self.edge_proj(edge_attr.float()) + + # Message passing + h = x + for conv, norm in zip(self.convs, self.norms): + if self.gnn_type == "gcn": + h_next = conv(h, edge_index) # GCNConv ignores edge_attr + elif self.gnn_type == "gin": + h_next = conv(h, edge_index) # GINConv ignores edge_attr + else: # gine + h_next = conv(h, edge_index, e) + + h_next = norm(h_next) + h_next = self.act(h_next) + + if self.residual and h_next.shape == h.shape: + h = h + h_next + else: + h = h_next + + if self.dropout_p > 0: + h = F.dropout(h, p=self.dropout_p, training=self.training) + + g = self._readout(h, batch) + return g # [B, emb_dim] + + +def build_gnn_encoder( + in_dim_node: int, + emb_dim: int, + num_layers: int = 5, + gnn_type: Literal["gine", "gin", "gcn"] = "gine", + in_dim_edge: int = 0, + act: str = "relu", + dropout: float = 0.0, + residual: bool = True, + norm: Literal["batch", "layer", "none"] = "batch", + readout: Literal["mean", "sum", "max"] = "mean", +) -> GNNEncoder: + """ + Factory to create a GNNEncoder with a consistent, minimal API. + Prefer calling this from model.py so encoder construction is centralized. + """ + return GNNEncoder( + in_dim_node=in_dim_node, + emb_dim=emb_dim, + num_layers=num_layers, + gnn_type=gnn_type, + in_dim_edge=in_dim_edge, + act=act, + dropout=dropout, + residual=residual, + norm=norm, + readout=readout, + ) + + +__all__ = ["GNNEncoder", "build_gnn_encoder"] diff --git a/src/data_builder.py b/src/data_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..695aa3c71c7bd8cc22ef29a4ba3c18aadd6ddb97 --- /dev/null +++ b/src/data_builder.py @@ -0,0 +1,818 @@ +# data_builder.py +from __future__ import annotations + +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Sequence +import json +import warnings + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import Dataset +from torch_geometric.data import Data + +# RDKit is required +from rdkit import Chem +from rdkit.Chem.rdchem import HybridizationType, BondType, BondStereo + +# --------------------------------------------------------- +# Fidelity handling +# --------------------------------------------------------- + +FID_PRIORITY = ["exp", "dft", "md", "gc"] # internal lower-case canonical order + + +def _norm_fid(fid: str) -> str: + return fid.strip().lower() + + +def _ensure_targets_order(requested: Sequence[str]) -> List[str]: + seen = set() + ordered = [] + for t in requested: + key = t.strip() + if key in seen: + continue + seen.add(key) + ordered.append(key) + return ordered + + +# --------------------------------------------------------- +# RDKit featurization +# --------------------------------------------------------- + +_ATOMS = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I"] +_ATOM2IDX = {s: i for i, s in enumerate(_ATOMS)} +_HYBS = [HybridizationType.SP, HybridizationType.SP2, HybridizationType.SP3, HybridizationType.SP3D, HybridizationType.SP3D2] +_HYB2IDX = {h: i for i, h in enumerate(_HYBS)} +_BOND_STEREOS = [ + BondStereo.STEREONONE, + BondStereo.STEREOANY, + BondStereo.STEREOZ, + BondStereo.STEREOE, + BondStereo.STEREOCIS, + BondStereo.STEREOTRANS, +] +_STEREO2IDX = {s: i for i, s in enumerate(_BOND_STEREOS)} + + +def 
_one_hot(index: int, size: int) -> List[float]: + v = [0.0] * size + if 0 <= index < size: + v[index] = 1.0 + return v + + +def atom_features(atom: Chem.Atom) -> List[float]: + # Element one-hot with "other" + elem_idx = _ATOM2IDX.get(atom.GetSymbol(), None) + elem_oh = _one_hot(elem_idx if elem_idx is not None else len(_ATOMS), len(_ATOMS) + 1) + + # Degree one-hot up to 5 (bucket 5+) + deg = min(int(atom.GetDegree()), 5) + deg_oh = _one_hot(deg, 6) + + # Formal charge one-hot in [-2,-1,0,+1,+2] + fc = max(-2, min(2, int(atom.GetFormalCharge()))) + fc_oh = _one_hot(fc + 2, 5) + + # Aromatic, in ring flags + aromatic = [1.0 if atom.GetIsAromatic() else 0.0] + in_ring = [1.0 if atom.IsInRing() else 0.0] + + # Hybridization one-hot with "other" + hyb_idx = _HYB2IDX.get(atom.GetHybridization(), None) + hyb_oh = _one_hot(hyb_idx if hyb_idx is not None else len(_HYBS), len(_HYBS) + 1) + + # Implicit H count capped at 4 + imp_h = min(int(atom.GetTotalNumHs(includeNeighbors=True)), 4) + imp_h_oh = _one_hot(imp_h, 5) + + # length: 11+6+5+1+1+6+5 = 35 (element has 11 buckets incl. "other") + feats = elem_oh + deg_oh + fc_oh + aromatic + in_ring + hyb_oh + imp_h_oh + return feats + + +def bond_features(bond: Chem.Bond) -> List[float]: + bt = bond.GetBondType() + single = 1.0 if bt == BondType.SINGLE else 0.0 + double = 1.0 if bt == BondType.DOUBLE else 0.0 + triple = 1.0 if bt == BondType.TRIPLE else 0.0 + aromatic = 1.0 if bt == BondType.AROMATIC else 0.0 + conj = 1.0 if bond.GetIsConjugated() else 0.0 + in_ring = 1.0 if bond.IsInRing() else 0.0 + stereo_oh = _one_hot(_STEREO2IDX.get(bond.GetStereo(), 0), len(_BOND_STEREOS)) + # length: 4 + 1 + 1 + 6 = 12 + return [single, double, triple, aromatic, conj, in_ring] + stereo_oh + + +def featurize_smiles(smiles: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + raise ValueError(f"RDKit failed to parse SMILES: {smiles}") + + # Nodes + x = torch.tensor([atom_features(a) for a in mol.GetAtoms()], dtype=torch.float32) + + # Edges (bidirectional) + rows, cols, eattr = [], [], [] + for b in mol.GetBonds(): + i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx() + bf = bond_features(b) + rows.extend([i, j]) + cols.extend([j, i]) + eattr.extend([bf, bf]) + + if not rows: + # single-atom molecules, add a dummy self-loop edge + rows, cols = [0], [0] + eattr = [[0.0] * 12] + + edge_index = torch.tensor([rows, cols], dtype=torch.long) + edge_attr = torch.tensor(eattr, dtype=torch.float32) + return x, edge_index, edge_attr + + +# --------------------------------------------------------- +# CSV discovery and reading +# --------------------------------------------------------- + +def discover_target_fid_csvs( + root: Path, + targets: Sequence[str], + fidelities: Sequence[str], +) -> Dict[tuple[str, str], Path]: + """ + Discover CSV files for (target, fidelity) pairs. + + Supported layouts (case-insensitive): + + 1) {root}/{fid}/{target}.csv + e.g. datafull/MD/SHEAR.csv, datafull/exp/cp.csv + + 2) {root}/{target}_{fid}.csv + e.g. 
datafull/SHEAR_MD.csv, datafull/cp_exp.csv + + Matching is STRICT: + - target and fid must appear as full '_' tokens in the stem + - no substring matching, so 'he' will NOT match 'shear_md.csv' + """ + root = Path(root) + targets = _ensure_targets_order(targets) + fids_lc = [_norm_fid(f) for f in fidelities] + + # Collect all CSVs under root + all_paths = list(root.rglob("*.csv")) + + # Pre-index: (parent_name_lower, stem_lower, tokens_lower) + indexed = [] + for p in all_paths: + parent = p.parent.name.lower() + stem = p.stem.lower() # filename without extension + tokens = stem.split("_") + tokens_l = [t.lower() for t in tokens] + indexed.append((p, parent, stem, tokens_l)) + + mapping: Dict[tuple[str, str], Path] = {} + + for fid in fids_lc: + fid_l = fid.strip().lower() + + for tgt in targets: + tgt_l = tgt.strip().lower() + + # ---- 1) Prefer explicit folder layout: {root}/{fid}/{target}.csv ---- + # parent == fid AND stem == target (case-insensitive) + folder_matches = [ + p for (p, parent, stem, tokens_l) in indexed + if parent == fid_l and stem == tgt_l + ] + if folder_matches: + # If you ever get more than one, it’s a config problem + if len(folder_matches) > 1: + warnings.warn( + f"[discover_target_fid_csvs] Multiple matches for " + f"target='{tgt}' fid='{fid}' under folder layout: " + + ", ".join(str(p) for p in folder_matches) + ) + mapping[(tgt, fid)] = folder_matches[0] + continue + + # ---- 2) Fallback: {target}_{fid}.csv anywhere under root ---- + # require BOTH tgt and fid as full '_' tokens + token_matches = [ + p for (p, parent, stem, tokens_l) in indexed + if (tgt_l in tokens_l) and (fid_l in tokens_l) + ] + + if token_matches: + if len(token_matches) > 1: + warnings.warn( + f"[discover_target_fid_csvs] Multiple token matches for " + f"target='{tgt}' fid='{fid}': " + + ", ".join(str(p) for p in token_matches) + ) + mapping[(tgt, fid)] = token_matches[0] + continue + + # If neither layout exists, we simply do not add (tgt, fid) to mapping. + # build_long_table will just skip that combination. + # You can enable a warning if you want: + # warnings.warn(f"[discover_target_fid_csvs] No CSV for target='{tgt}', fid='{fid}'") + + return mapping + + +def read_target_csv(path: Path, target: str) -> pd.DataFrame: + """ + Accepts: + - 'smiles' column (case-insensitive) + - value column named '{target}' or one of ['value','y' or lower-case target] + Deduplicates by SMILES with mean. + """ + df = pd.read_csv(path) + + # smiles column + smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None) + if smiles_col is None: + raise ValueError(f"{path} must contain a 'smiles' column.") + df = df.rename(columns={smiles_col: "smiles"}) + + # value column + val_col = None + if target in df.columns: + val_col = target + else: + for c in df.columns: + if c.lower() in ("value", "y", target.lower()): + val_col = c + break + if val_col is None: + raise ValueError(f"{path} must contain a '{target}' column or one of ['value','y'].") + + df = df[["smiles", val_col]].copy() + df = df.dropna(subset=[val_col]) + df[val_col] = pd.to_numeric(df[val_col], errors="coerce") + df = df.dropna(subset=[val_col]) + + # Deduplicate SMILES by mean + if df.duplicated(subset=["smiles"]).any(): + warnings.warn(f"[data_builder] Duplicates by SMILES in {path}. 
Averaging duplicates.") + df = df.groupby("smiles", as_index=False)[val_col].mean() + + return df.rename(columns={val_col: target}) + + +def build_long_table(root: Path, targets: Sequence[str], fidelities: Sequence[str]) -> pd.DataFrame: + """ + Returns long-form table with columns: [smiles, fid, fid_idx, target, value] + """ + targets = _ensure_targets_order(targets) + fids_lc = [_norm_fid(f) for f in fidelities] + + mapping = discover_target_fid_csvs(root, targets, fidelities) + if not mapping: + raise FileNotFoundError(f"No CSVs found under {root} for the given targets and fidelities.") + + long_rows = [] + for (tgt, fid), path in mapping.items(): + df = read_target_csv(path, tgt) + df["fid"] = _norm_fid(fid) + df["target"] = tgt + df = df.rename(columns={tgt: "value"}) + long_rows.append(df[["smiles", "fid", "target", "value"]]) + + long = pd.concat(long_rows, axis=0, ignore_index=True) + + # attach fid index by priority + fid2idx = {f: i for i, f in enumerate(FID_PRIORITY)} + long["fid"] = long["fid"].str.lower() + unknown = sorted(set(long["fid"]) - set(fid2idx.keys())) + if unknown: + warnings.warn(f"[data_builder] Unknown fidelities found: {unknown}. Appending after known ones.") + start = len(fid2idx) + for i, f in enumerate(unknown): + fid2idx[f] = start + i + + long["fid_idx"] = long["fid"].map(fid2idx) + return long + + +def pivot_to_rows_by_smiles_fid(long: pd.DataFrame, targets: Sequence[str]) -> pd.DataFrame: + """ + Input: long table [smiles, fid, fid_idx, target, value] + Output: row-per-(smiles,fid) with wide columns for each target + """ + targets = _ensure_targets_order(targets) + wide = long.pivot_table(index=["smiles", "fid", "fid_idx"], columns="target", values="value", aggfunc="mean") + wide = wide.reset_index() + + for t in targets: + if t not in wide.columns: + wide[t] = np.nan + + cols = ["smiles", "fid", "fid_idx"] + list(targets) + return wide[cols] + + +# --------------------------------------------------------- +# Grouped split by SMILES and transforms/normalization +# --------------------------------------------------------- + +def grouped_split_by_smiles( + df_rows: pd.DataFrame, + val_ratio: float = 0.1, + test_ratio: float = 0.1, + seed: int = 42, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + uniq = df_rows["smiles"].drop_duplicates().values + rng = np.random.default_rng(seed) + uniq = rng.permutation(uniq) + + n = len(uniq) + n_test = int(round(n * test_ratio)) + n_val = int(round(n * val_ratio)) + + test_smiles = set(uniq[:n_test]) + val_smiles = set(uniq[n_test:n_test + n_val]) + train_smiles = set(uniq[n_test + n_val:]) + + train_idx = df_rows.index[df_rows["smiles"].isin(train_smiles)].to_numpy() + val_idx = df_rows.index[df_rows["smiles"].isin(val_smiles)].to_numpy() + test_idx = df_rows.index[df_rows["smiles"].isin(test_smiles)].to_numpy() + return train_idx, val_idx, test_idx + + +# ---------------- Enhanced TargetScaler with per-task transforms ---------------- + +class TargetScaler: + """ + Per-task transform + standardization fitted on the training split only. 
+ + - transforms[t] in {"identity","log10"} + - eps[t] is added before log for numerical safety (only used if transforms[t]=="log10") + - mean/std are computed in the *transformed* domain + """ + def __init__(self, transforms: Optional[Sequence[str]] = None, eps: Optional[Sequence[float] | torch.Tensor] = None): + self.mean: Optional[torch.Tensor] = None # [T] (transformed domain) + self.std: Optional[torch.Tensor] = None # [T] (transformed domain) + self.transforms: List[str] = [str(t).lower() for t in transforms] if transforms is not None else [] + if eps is None: + self.eps: Optional[torch.Tensor] = None + else: + self.eps = torch.as_tensor(eps, dtype=torch.float32) + self._tiny = 1e-12 + + def _ensure_cfg(self, T: int): + if not self.transforms or len(self.transforms) != T: + self.transforms = ["identity"] * T + if self.eps is None or self.eps.numel() != T: + self.eps = torch.zeros(T, dtype=torch.float32) + + def _forward_transform_only(self, y: torch.Tensor) -> torch.Tensor: + """ + Apply per-task transforms *before* standardization. + y: [N, T] in original units. Returns transformed y_tf in same shape. + """ + out = y.clone() + T = out.size(1) + self._ensure_cfg(T) + for t in range(T): + if self.transforms[t] == "log10": + out[:, t] = torch.log10(torch.clamp(out[:, t] + self.eps[t], min=self._tiny)) + return out + + def _inverse_transform_only(self, y_tf: torch.Tensor) -> torch.Tensor: + """ + Inverse the per-task transform (no standardization here). + y_tf: [N, T] in transformed units. + """ + out = y_tf.clone() + T = out.size(1) + self._ensure_cfg(T) + for t in range(T): + if self.transforms[t] == "log10": + out[:, t] = (10.0 ** out[:, t]) - self.eps[t] + return out + + def fit(self, y: torch.Tensor, mask: torch.Tensor): + """ + y: [N, T] original units; mask: [N, T] bool + Chooses eps automatically if not provided; mean/std computed in transformed space. + """ + T = y.size(1) + self._ensure_cfg(T) + + if self.eps is None or self.eps.numel() != T: + # Auto epsilon: 0.1 * min positive per task (robust) + eps_vals: List[float] = [] + y_np = y.detach().cpu().numpy() + m_np = mask.detach().cpu().numpy().astype(bool) + for t in range(T): + if self.transforms[t] != "log10": + eps_vals.append(0.0) + continue + vals = y_np[m_np[:, t], t] + pos = vals[vals > 0] + if pos.size == 0: + eps_vals.append(1e-8) + else: + eps_vals.append(0.1 * float(max(np.min(pos), 1e-8))) + self.eps = torch.tensor(eps_vals, dtype=torch.float32) + + y_tf = self._forward_transform_only(y) + eps = 1e-8 + y_masked = torch.where(mask, y_tf, torch.zeros_like(y_tf)) + counts = mask.sum(dim=0).clamp_min(1) + mean = y_masked.sum(dim=0) / counts + var = ((torch.where(mask, y_tf - mean, torch.zeros_like(y_tf))) ** 2).sum(dim=0) / counts + std = torch.sqrt(var + eps) + self.mean, self.std = mean, std + + def transform(self, y: torch.Tensor) -> torch.Tensor: + y_tf = self._forward_transform_only(y) + return (y_tf - self.mean) / self.std + + def inverse(self, y_std: torch.Tensor) -> torch.Tensor: + """ + Inverse standardization + inverse transform → original units. 
+ y_std: [N, T] in standardized-transformed space + """ + y_tf = y_std * self.std + self.mean + return self._inverse_transform_only(y_tf) + + def state_dict(self) -> Dict[str, torch.Tensor | List[str]]: + return { + "mean": self.mean, + "std": self.std, + "transforms": self.transforms, + "eps": self.eps, + } + + def load_state_dict(self, state: Dict[str, torch.Tensor | List[str]]): + self.mean = state["mean"] + self.std = state["std"] + self.transforms = [str(t) for t in state.get("transforms", [])] + eps = state.get("eps", None) + self.eps = torch.as_tensor(eps, dtype=torch.float32) if eps is not None else None + + +def auto_select_task_transforms( + y_train: torch.Tensor, # [N, T] original units (train split only) + mask_train: torch.Tensor, # [N, T] bool + task_names: Sequence[str], + *, + min_pos_frac: float = 0.95, # ≥95% of labels positive + orders_threshold: float = 2.0, # ≥2 orders of magnitude between p95 and p5 + tiny: float = 1e-12, +) -> tuple[List[str], torch.Tensor]: + """ + Decide per-task transform: "log10" if (mostly-positive AND large dynamic range), else "identity". + Returns (transforms, eps_vector) where eps is only used for log tasks. + """ + Y = y_train.detach().cpu().numpy() + M = mask_train.detach().cpu().numpy().astype(bool) + + transforms: List[str] = [] + eps_vals: List[float] = [] + + for t in range(Y.shape[1]): + yt = Y[M[:, t], t] + if yt.size == 0: + transforms.append("identity") + eps_vals.append(0.0) + continue + + pos_frac = (yt > 0).mean() + p5 = float(np.percentile(yt, 5)) + p95 = float(np.percentile(yt, 95)) + denom = max(p5, tiny) + dyn_orders = float(np.log10(max(p95 / denom, 1.0))) + use_log = (pos_frac >= min_pos_frac) and (dyn_orders >= orders_threshold) + + if use_log: + pos_vals = yt[yt > 0] + if pos_vals.size == 0: + eps_vals.append(1e-8) + else: + eps_vals.append(0.1 * float(max(np.min(pos_vals), 1e-8))) + transforms.append("log10") + else: + transforms.append("identity") + eps_vals.append(0.0) + + return transforms, torch.tensor(eps_vals, dtype=torch.float32) + + +# --------------------------------------------------------- +# Dataset +# --------------------------------------------------------- + +class MultiFidelityMoleculeDataset(Dataset): + """ + Each item is a PyG Data with: + - x: [N_nodes, F_node] + - edge_index: [2, N_edges] + - edge_attr: [N_edges, F_edge] + - y: [T] normalized targets (zeros where missing) + - y_mask: [T] bool mask of present targets + - fid_idx: [1] long + - .smiles and .fid_str added for debugging + + Targets are kept in the exact order provided by the user. 
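+
+ Collation note (illustrative, not in the original docstring): because y and y_mask are stored as [1, T] per item, a torch_geometric.loader.DataLoader batch of size B exposes batch.y / batch.y_mask with shape [B, T] and batch.fid_idx with shape [B].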
+ """ + def __init__( + self, + rows: pd.DataFrame, + targets: Sequence[str], + scaler: Optional[TargetScaler], + smiles_graph_cache: Dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + ): + super().__init__() + self.rows = rows.reset_index(drop=True).copy() + self.targets = _ensure_targets_order(targets) + self.scaler = scaler + self.smiles_graph_cache = smiles_graph_cache + + # Build y and mask tensors + ys, masks = [], [] + for _, r in self.rows.iterrows(): + yv, mv = [], [] + for t in self.targets: + v = r[t] + if pd.isna(v): + yv.append(np.nan) + mv.append(False) + else: + yv.append(float(v)) + mv.append(True) + ys.append(yv) + masks.append(mv) + + y = torch.tensor(np.array(ys, dtype=np.float32)) # [N, T] + mask = torch.tensor(np.array(masks, dtype=np.bool_)) + + if scaler is not None and scaler.mean is not None: + y_norm = torch.where(mask, scaler.transform(y), torch.zeros_like(y)) + else: + y_norm = y + + self.y = y_norm + self.mask = mask + + # Input dims + any_smiles = self.rows.iloc[0]["smiles"] + x0, _, e0 = smiles_graph_cache[any_smiles] + self.in_dim_node = x0.shape[1] + self.in_dim_edge = e0.shape[1] + + # Fidelity metadata for reference (local indexing in this dataset) + self.fids = sorted( + self.rows["fid"].str.lower().unique().tolist(), + key=lambda f: (FID_PRIORITY + [f]).index(f) if f in FID_PRIORITY else len(FID_PRIORITY), + ) + self.fid2idx = {f: i for i, f in enumerate(self.fids)} + self.rows["fid_idx_local"] = self.rows["fid"].str.lower().map(self.fid2idx) + + def __len__(self) -> int: + return len(self.rows) + + def __getitem__(self, idx: int) -> Data: + idx = int(idx) + r = self.rows.iloc[idx] + smi = r["smiles"] + + x, edge_index, edge_attr = self.smiles_graph_cache[smi] + # Ensure [1, T] so batches become [B, T] + y_i = self.y[idx].clone().unsqueeze(0) # [1, T] + m_i = self.mask[idx].clone().unsqueeze(0) # [1, T] + fid_idx = int(r["fid_idx_local"]) + + d = Data( + x=x.clone(), + edge_index=edge_index.clone(), + edge_attr=edge_attr.clone(), + y=y_i, + y_mask=m_i, + fid_idx=torch.tensor([fid_idx], dtype=torch.long), + ) + d.smiles = smi + d.fid_str = r["fid"] + return d + + +def subsample_train_indices( + rows: pd.DataFrame, + train_idx: np.ndarray, + *, + target: Optional[str], + fidelity: Optional[str], + pct: float = 1.0, + seed: int = 137, +) -> np.ndarray: + """ + Return a filtered train_idx that keeps only a 'pct' fraction (0] + """ + if target is None or fidelity is None or pct >= 0.999: + return train_idx + + if target not in rows.columns: + return train_idx + + fid_lc = fidelity.strip().lower() + + # Identify TRAIN rows in the specified block: matching fid and having a label for 'target' + train_rows = rows.iloc[train_idx] + block_mask = (train_rows["fid"].str.lower() == fid_lc) & (~train_rows[target].isna()) + if not bool(block_mask.any()): + return train_idx # nothing to subsample + + # Sample by unique SMILES (stable & grouped) + smiles_all = pd.Index(train_rows.loc[block_mask, "smiles"].unique()) + n_all = len(smiles_all) + if n_all == 0: + return train_idx + + if pct <= 0.0: + pct = 0.0001 + n_keep = max(1, int(round(pct * n_all))) + + rng = np.random.RandomState(int(seed)) + smiles_sorted = np.array(sorted(smiles_all.tolist())) + keep_smiles = set(rng.choice(smiles_sorted, size=n_keep, replace=False).tolist()) + + # Keep all non-block rows; within block keep selected SMILES + keep_mask_local = (~block_mask) | (train_rows["smiles"].isin(keep_smiles)) + kept_train_idx = train_rows.index[keep_mask_local].to_numpy() + return kept_train_idx + + 
+# --------------------------------------------------------- +# High-level builder +# --------------------------------------------------------- + +def build_dataset_from_dir( + root_dir: str | Path, + targets: Sequence[str], + fidelities: Sequence[str] = ("exp", "dft", "md", "gc"), + val_ratio: float = 0.1, + test_ratio: float = 0.1, + seed: int = 42, + save_splits_path: Optional[str | Path] = None, + # Optional subsampling of a (target, fidelity) block in TRAIN + subsample_target: Optional[str] = None, + subsample_fidelity: Optional[str] = None, + subsample_pct: float = 1.0, + subsample_seed: int = 137, + # -------- NEW: auto/explicit log transforms -------- + auto_log: bool = True, + log_orders_threshold: float = 2.0, + log_min_pos_frac: float = 0.95, + explicit_log_targets: Optional[Sequence[str]] = None, # e.g. ["permeability"] +) -> tuple[MultiFidelityMoleculeDataset, MultiFidelityMoleculeDataset, MultiFidelityMoleculeDataset, TargetScaler]: + """ + Returns train_ds, val_ds, test_ds, scaler. + + - Discovers CSVs for requested targets and fidelities + - Builds a row-per-(smiles,fid) table with columns for each target + - Splits by unique SMILES to avoid leakage across fidelity or targets + - Fits transform+normalization on the training split only, applies to val/test + - Builds RDKit graphs once per unique SMILES and reuses them + + NEW: + - Auto per-task transform selection ("log10" vs "identity") by criteria + - Optional explicit override via explicit_log_targets + """ + root = Path(root_dir) + targets = _ensure_targets_order(targets) + fids_lc = [_norm_fid(f) for f in fidelities] + + # Build long and pivot to rows + long = build_long_table(root, targets, fids_lc) + rows = pivot_to_rows_by_smiles_fid(long, targets) + + # Deterministic grouped split by SMILES + if save_splits_path is not None and Path(save_splits_path).exists(): + with open(save_splits_path, "r") as f: + split_obj = json.load(f) + train_smiles = set(split_obj["train_smiles"]) + val_smiles = set(split_obj["val_smiles"]) + test_smiles = set(split_obj["test_smiles"]) + train_idx = rows.index[rows["smiles"].isin(train_smiles)].to_numpy() + val_idx = rows.index[rows["smiles"].isin(val_smiles)].to_numpy() + test_idx = rows.index[rows["smiles"].isin(test_smiles)].to_numpy() + else: + train_idx, val_idx, test_idx = grouped_split_by_smiles(rows, val_ratio=val_ratio, test_ratio=test_ratio, seed=seed) + if save_splits_path is not None: + split_obj = { + "train_smiles": rows.iloc[train_idx]["smiles"].drop_duplicates().tolist(), + "val_smiles": rows.iloc[val_idx]["smiles"].drop_duplicates().tolist(), + "test_smiles": rows.iloc[test_idx]["smiles"].drop_duplicates().tolist(), + "seed": seed, + "val_ratio": val_ratio, + "test_ratio": test_ratio, + } + Path(save_splits_path).parent.mkdir(parents=True, exist_ok=True) + with open(save_splits_path, "w") as f: + json.dump(split_obj, f, indent=2) + + # Build RDKit graphs once per unique SMILES + uniq_smiles = rows["smiles"].drop_duplicates().tolist() + smiles_graph_cache: Dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {} + for smi in uniq_smiles: + try: + x, edge_index, edge_attr = featurize_smiles(smi) + smiles_graph_cache[smi] = (x, edge_index, edge_attr) + except Exception as e: + warnings.warn(f"[data_builder] Dropping SMILES due to RDKit parse error: {smi} ({e})") + + # Filter rows to those that featurized successfully + rows = rows[rows["smiles"].isin(smiles_graph_cache.keys())].reset_index(drop=True) + + # Re-map indices after filtering using smiles membership + 
train_idx = rows.index[rows["smiles"].isin(set(rows.iloc[train_idx]["smiles"]))].to_numpy() + val_idx = rows.index[rows["smiles"].isin(set(rows.iloc[val_idx]["smiles"]))].to_numpy() + test_idx = rows.index[rows["smiles"].isin(set(rows.iloc[test_idx]["smiles"]))].to_numpy() + + # Optional subsampling (train only) for a specific (target, fidelity) block + train_idx = subsample_train_indices( + rows, + train_idx, + target=subsample_target, + fidelity=subsample_fidelity, + pct=float(subsample_pct), + seed=int(subsample_seed), + ) + + # Fit scaler on training split only + def build_y_mask(df_slice: pd.DataFrame) -> tuple[torch.Tensor, torch.Tensor]: + ys, ms = [], [] + for _, r in df_slice.iterrows(): + yv, mv = [], [] + for t in targets: + v = r[t] + if pd.isna(v): + yv.append(np.nan) + mv.append(False) + else: + yv.append(float(v)) + mv.append(True) + ys.append(yv) + ms.append(mv) + y = torch.tensor(np.array(ys, dtype=np.float32)) + mask = torch.tensor(np.array(ms, dtype=np.bool_)) + return y, mask + + y_train, mask_train = build_y_mask(rows.iloc[train_idx]) + + # Decide transforms per task + if explicit_log_targets: + explicit_set = set(explicit_log_targets) + transforms = [("log10" if t in explicit_set else "identity") for t in targets] + eps_vec = None # will be auto-chosen in scaler.fit if not provided + elif auto_log: + transforms, eps_vec = auto_select_task_transforms( + y_train, + mask_train, + targets, + min_pos_frac=float(log_min_pos_frac), + orders_threshold=float(log_orders_threshold), + ) + else: + transforms, eps_vec = (["identity"] * len(targets), None) + + scaler = TargetScaler(transforms=transforms, eps=eps_vec) + scaler.fit(y_train, mask_train) + + # Build datasets + train_rows = rows.iloc[train_idx].reset_index(drop=True) + val_rows = rows.iloc[val_idx].reset_index(drop=True) + test_rows = rows.iloc[test_idx].reset_index(drop=True) + + train_ds = MultiFidelityMoleculeDataset(train_rows, targets, scaler, smiles_graph_cache) + val_ds = MultiFidelityMoleculeDataset(val_rows, targets, scaler, smiles_graph_cache) + test_ds = MultiFidelityMoleculeDataset(test_rows, targets, scaler, smiles_graph_cache) + + return train_ds, val_ds, test_ds, scaler + + +__all__ = [ + "build_dataset_from_dir", + "discover_target_fid_csvs", + "read_target_csv", + "build_long_table", + "pivot_to_rows_by_smiles_fid", + "grouped_split_by_smiles", + "TargetScaler", + "MultiFidelityMoleculeDataset", + "atom_features", + "bond_features", + "featurize_smiles", + "auto_select_task_transforms", +] diff --git a/src/discover_llm.py b/src/discover_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..dafa620a377eba4ac26ac3e659b90f7221942bf9 --- /dev/null +++ b/src/discover_llm.py @@ -0,0 +1,829 @@ +# src/discovery.py +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +from rdkit import Chem, DataStructs +from rdkit.Chem import AllChem +from . 
import sascorer + +# Reuse your canonicalizer if you want; otherwise keep local +def canonicalize_smiles(smiles: str) -> Optional[str]: + s = (smiles or "").strip() + if not s: + return None + m = Chem.MolFromSmiles(s) + if m is None: + return None + return Chem.MolToSmiles(m, canonical=True) + + +# ------------------------- +# Spec schema (minimal v0) +# ------------------------- +@dataclass +class DiscoverySpec: + dataset: List[str] # ["PI1M_PROPERTY.parquet", "POLYINFO_PROPERTY.parquet"] + polyinfo: str # "POLYINFO_PROPERTY.parquet" + polyinfo_csv: str # "POLYINFO.csv" + + hard_constraints: Dict[str, Dict[str, float]] # { "tg": {"min": 400}, "tc": {"max": 0.3} } + objectives: List[Dict[str, str]] # [{"property":"cp","goal":"maximize"}, ...] + + max_pool: int = 200000 # legacy (kept for compatibility; aligned to pareto_max) + pareto_max: int = 50000 # cap points used for Pareto + diversity fingerprinting + max_candidates: int = 30 # final output size + max_pareto_fronts: int = 5 # how many Pareto layers to keep for candidate pool + min_distance: float = 0.30 # diversity threshold in Tanimoto distance + fingerprint: str = "morgan" # morgan only for now + random_seed: int = 7 + use_canonical_smiles: bool = True + use_full_data: bool = False + trust_weights: Dict[str, float] | None = None + selection_weights: Dict[str, float] | None = None + + +# ------------------------- +# Property metadata (local to discovery_llm) +# ------------------------- +PROPERTY_META: Dict[str, Dict[str, str]] = { + # Thermal + "tm": {"name": "Melting temperature", "unit": "K"}, + "tg": {"name": "Glass transition temperature", "unit": "K"}, + "td": {"name": "Thermal diffusivity", "unit": "m^2/s"}, + "tc": {"name": "Thermal conductivity", "unit": "W/m-K"}, + "cp": {"name": "Specific heat capacity", "unit": "J/kg-K"}, + # Mechanical + "young": {"name": "Young's modulus", "unit": "GPa"}, + "shear": {"name": "Shear modulus", "unit": "GPa"}, + "bulk": {"name": "Bulk modulus", "unit": "GPa"}, + "poisson": {"name": "Poisson ratio", "unit": "-"}, + # Transport + "visc": {"name": "Viscosity", "unit": "Pa-s"}, + "dif": {"name": "Diffusivity", "unit": "cm^2/s"}, + # Gas permeability + "phe": {"name": "He permeability", "unit": "Barrer"}, + "ph2": {"name": "H2 permeability", "unit": "Barrer"}, + "pco2": {"name": "CO2 permeability", "unit": "Barrer"}, + "pn2": {"name": "N2 permeability", "unit": "Barrer"}, + "po2": {"name": "O2 permeability", "unit": "Barrer"}, + "pch4": {"name": "CH4 permeability", "unit": "Barrer"}, + # Electronic / Optical + "alpha": {"name": "Polarizability", "unit": "a.u."}, + "homo": {"name": "HOMO energy", "unit": "eV"}, + "lumo": {"name": "LUMO energy", "unit": "eV"}, + "bandgap": {"name": "Band gap", "unit": "eV"}, + "mu": {"name": "Dipole moment", "unit": "Debye"}, + "etotal": {"name": "Total electronic energy", "unit": "eV"}, + "ri": {"name": "Refractive index", "unit": "-"}, + "dc": {"name": "Dielectric constant", "unit": "-"}, + "pe": {"name": "Permittivity", "unit": "-"}, + # Structural / Physical + "rg": {"name": "Radius of gyration", "unit": "A"}, + "rho": {"name": "Density", "unit": "g/cm^3"}, +} + + +# ------------------------- +# Column mapping +# ------------------------- +def mean_col(prop_key: str) -> str: + return f"mean_{prop_key.lower()}" + +def std_col(prop_key: str) -> str: + return f"std_{prop_key.lower()}" + + +def normalize_weights(weights: Dict[str, float], defaults: Dict[str, float]) -> Dict[str, float]: + out: Dict[str, float] = {} + for k, v in defaults.items(): + try: + vv = 
float(weights.get(k, v)) + except Exception: + vv = float(v) + out[k] = max(0.0, vv) + s = float(sum(out.values())) + if s <= 0.0: + return defaults.copy() + return {k: float(v / s) for k, v in out.items()} + +def spec_from_dict(obj: dict, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec: + pareto_max = int(obj.get("pareto_max", 50000)) + return DiscoverySpec( + dataset=list(dataset_path), + polyinfo=polyinfo_path, + polyinfo_csv=polyinfo_csv_path, + hard_constraints=obj.get("hard_constraints", {}), + objectives=obj.get("objectives", []), + # Legacy field kept for compatibility; effectively collapsed to pareto_max. + max_pool=pareto_max, + pareto_max=pareto_max, + max_candidates=int(obj.get("max_candidates", 30)), + max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)), + min_distance=float(obj.get("min_distance", 0.30)), + fingerprint=str(obj.get("fingerprint", "morgan")), + random_seed=int(obj.get("random_seed", 7)), + use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)), + use_full_data=bool(obj.get("use_full_data", False)), + trust_weights=obj.get("trust_weights"), + selection_weights=obj.get("selection_weights"), + ) + +# ------------------------- +# Parquet loading (safe) +# ------------------------- +def load_parquet_columns(path: str | List[str], columns: List[str]) -> pd.DataFrame: + """ + Load only requested columns from Parquet (critical for 1M rows). + Accepts a single path or a list of paths and concatenates rows. + """ + def _load_one(fp: str, req_cols: List[str]) -> pd.DataFrame: + available: list[str] + try: + import pyarrow.parquet as pq + + pf = pq.ParquetFile(fp) + available = [str(c) for c in pf.schema.names] + except Exception: + # If schema probing fails, fall back to direct read with requested columns. + return pd.read_parquet(fp, columns=req_cols) + + available_set = set(available) + lower_to_actual = {c.lower(): c for c in available} + + # Resolve requested names against actual parquet schema. + resolved: dict[str, str] = {} + for req in req_cols: + if req in available_set: + resolved[req] = req + continue + alt = lower_to_actual.get(str(req).lower()) + if alt is not None: + resolved[req] = alt + + use_cols = sorted(set(resolved.values())) + if not use_cols: + return pd.DataFrame(columns=req_cols) + + out = pd.read_parquet(fp, columns=use_cols) + for req in req_cols: + src = resolved.get(req) + if src is None: + out[req] = np.nan + elif src != req: + out[req] = out[src] + return out[req_cols] + + if isinstance(path, (list, tuple)): + frames = [_load_one(p, columns) for p in path] + if not frames: + return pd.DataFrame(columns=columns) + return pd.concat(frames, ignore_index=True) + return _load_one(path, columns) + + +def normalize_smiles(smiles: str, use_canonical_smiles: bool) -> Optional[str]: + s = (smiles or "").strip() + if not s: + return None + if not use_canonical_smiles: + # Skip RDKit parsing entirely in fast mode. + return s + m = Chem.MolFromSmiles(s) + if m is None: + return None + if use_canonical_smiles: + return Chem.MolToSmiles(m, canonical=True) + return s + + +def load_polyinfo_index(polyinfo_csv_path: str, use_canonical_smiles: bool = True) -> pd.DataFrame: + """ + Expected CSV columns: SMILES, Polymer_Class, polymer_name (or common variants). + Returns dataframe with index on smiles_key and columns polymer_name/polymer_class. 
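+
+ Example (illustrative): load_polyinfo_index("POLYINFO.csv").loc[smiles_key, "polymer_name"]
+ gives the curated polymer name for a candidate whose smiles_key appears in PolyInfo.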
+ """ + df = pd.read_csv(polyinfo_csv_path) + + # normalize column names + cols = {c: c for c in df.columns} + # map typical names + if "SMILES" in cols: + df = df.rename(columns={"SMILES": "smiles"}) + elif "smiles" not in df.columns: + raise ValueError(f"{polyinfo_csv_path} missing SMILES/smiles column") + + if "Polymer_Name" in df.columns: + df = df.rename(columns={"Polymer_Name": "polymer_name"}) + if "polymer_Name" in df.columns: + df = df.rename(columns={"polymer_Name": "polymer_name"}) + if "Polymer_Class" in df.columns: + df = df.rename(columns={"Polymer_Class": "polymer_class"}) + + if "polymer_name" not in df.columns: + df["polymer_name"] = pd.NA + if "polymer_class" not in df.columns: + df["polymer_class"] = pd.NA + + df["smiles_key"] = df["smiles"].astype(str).map(lambda s: normalize_smiles(s, use_canonical_smiles)) + df = df.dropna(subset=["smiles_key"]).drop_duplicates("smiles_key") + df = df.set_index("smiles_key", drop=True) + return df[["polymer_name", "polymer_class"]] + + +# ------------------------- +# Pareto (2–3 objectives) +# ------------------------- +def pareto_front_mask(X: np.ndarray) -> np.ndarray: + """ + Returns mask for nondominated points. + X: (N, M), all objectives assumed to be minimized. + For maximize objectives, we invert before calling this. + """ + N = X.shape[0] + is_efficient = np.ones(N, dtype=bool) + for i in range(N): + if not is_efficient[i]: + continue + # any point that is <= in all dims and < in at least one dominates + dominates = np.all(X <= X[i], axis=1) & np.any(X < X[i], axis=1) + # if a point dominates i, mark i inefficient + if np.any(dominates): + is_efficient[i] = False + continue + # otherwise, i may dominate others + dominated_by_i = np.all(X[i] <= X, axis=1) & np.any(X[i] < X, axis=1) + is_efficient[dominated_by_i] = False + is_efficient[i] = True + return is_efficient + + +def pareto_layers(X: np.ndarray, max_layers: int = 10) -> np.ndarray: + """ + Returns layer index per point: 1 = Pareto front, 2 = second layer, ... + Unassigned points beyond max_layers get 0. + """ + N = X.shape[0] + layers = np.zeros(N, dtype=int) + remaining = np.arange(N) + + layer = 1 + while remaining.size > 0 and layer <= max_layers: + mask = pareto_front_mask(X[remaining]) + front_idx = remaining[mask] + layers[front_idx] = layer + remaining = remaining[~mask] + layer += 1 + return layers + + +def pareto_front_mask_chunked( + X: np.ndarray, + chunk_size: int = 100000, + progress_callback: Optional[Callable[[int, int], None]] = None, +) -> np.ndarray: + """ + Exact global Pareto front mask via chunk-local front reduction + global reconcile. 
+ This is exact for front-1: + 1) compute exact local front within each chunk + 2) union local fronts + 3) compute exact front on the union + """ + N = X.shape[0] + if N <= chunk_size: + if progress_callback is not None: + progress_callback(1, 1) + return pareto_front_mask(X) + + local_front_idx = [] + total_chunks = (N + chunk_size - 1) // chunk_size + done_chunks = 0 + for start in range(0, N, chunk_size): + end = min(start + chunk_size, N) + idx = np.arange(start, end) + mask_local = pareto_front_mask(X[idx]) + local_front_idx.append(idx[mask_local]) + done_chunks += 1 + if progress_callback is not None: + progress_callback(done_chunks, total_chunks) + + if not local_front_idx: + return np.zeros(N, dtype=bool) + + reduced_idx = np.concatenate(local_front_idx) + reduced_mask = pareto_front_mask(X[reduced_idx]) + front_idx = reduced_idx[reduced_mask] + + out = np.zeros(N, dtype=bool) + out[front_idx] = True + return out + + +def pareto_layers_chunked( + X: np.ndarray, + max_layers: int = 10, + chunk_size: int = 100000, + progress_callback: Optional[Callable[[int, int, int], None]] = None, +) -> np.ndarray: + """ + Exact Pareto layers using repeated exact chunked front extraction. + """ + N = X.shape[0] + layers = np.zeros(N, dtype=int) + remaining = np.arange(N) + layer = 1 + + while remaining.size > 0 and layer <= max_layers: + def on_chunk(done: int, total: int) -> None: + if progress_callback is not None: + progress_callback(layer, done, total) + + mask = pareto_front_mask_chunked(X[remaining], chunk_size=chunk_size, progress_callback=on_chunk) + front_idx = remaining[mask] + layers[front_idx] = layer + remaining = remaining[~mask] + layer += 1 + + return layers + + +# ------------------------- +# Fingerprints & diversity +# ------------------------- +def morgan_fp(smiles: str, radius: int = 2, nbits: int = 2048): + m = Chem.MolFromSmiles(smiles) + if m is None: + return None + return AllChem.GetMorganFingerprintAsBitVect(m, radius, nBits=nbits) + +def tanimoto_distance(fp1, fp2) -> float: + return 1.0 - DataStructs.TanimotoSimilarity(fp1, fp2) + +def greedy_diverse_select( + smiles_list: List[str], + scores: np.ndarray, + max_k: int, + min_dist: float, +) -> List[int]: + """ + Greedy selection by descending score, enforcing min Tanimoto distance. + Returns indices into smiles_list. + """ + fps = [] + valid_idx = [] + for i, s in enumerate(smiles_list): + fp = morgan_fp(s) + if fp is not None: + fps.append(fp) + valid_idx.append(i) + + if not valid_idx: + return [] + + # rank candidates (higher score first) + order = np.argsort(-scores[valid_idx]) + selected_global = [] + selected_fps = [] + + for oi in order: + i = valid_idx[oi] + fp_i = fps[oi] # aligned with valid_idx + ok = True + for fp_j in selected_fps: + if tanimoto_distance(fp_i, fp_j) < min_dist: + ok = False + break + if ok: + selected_global.append(i) + selected_fps.append(fp_i) + if len(selected_global) >= max_k: + break + + return selected_global + + +# ------------------------- +# Trust score (lightweight, robust) +# ------------------------- +def internal_consistency_penalty(row: pd.Series) -> float: + """ + Very simple physics/validity checks. Penalty in [0,1]. + Adjust/add rules later. 
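+
+ Example (illustrative): a row with mean_poisson = 0.6 and mean_tg > mean_tm fails the
+ Poisson-ratio bound and the Tg <= Tm check, contributing 2 violations out of the total
+ number of checks evaluated for that row.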
+ """ + viol = 0 + total = 0 + + def chk(cond: bool): + nonlocal viol, total + total += 1 + if not cond: + viol += 1 + + # positivity checks if present + for p in ["cp", "tc", "rho", "dif", "visc", "tg", "tm", "bandgap"]: + c = mean_col(p) + if c in row.index and pd.notna(row[c]): + if p in ["bandgap", "tg", "tm"]: + chk(float(row[c]) >= 0.0) + else: + chk(float(row[c]) > 0.0) + + # Poisson ratio bounds if present + if mean_col("poisson") in row.index and pd.notna(row[mean_col("poisson")]): + v = float(row[mean_col("poisson")]) + chk(0.0 <= v <= 0.5) + + # Tg <= Tm if both present + if mean_col("tg") in row.index and mean_col("tm") in row.index: + if pd.notna(row[mean_col("tg")]) and pd.notna(row[mean_col("tm")]): + chk(float(row[mean_col("tg")]) <= float(row[mean_col("tm")])) + + if total == 0: + return 0.0 + return viol / total + + +def synthesizability_score(smiles: str) -> float: + """ + RDKit SA-score based synthesizability proxy in [0,1]. + SA-score is ~[1 (easy), 10 (hard)]. + We map: 1 -> 1.0, 10 -> 0.0 + """ + m = Chem.MolFromSmiles(smiles) + if m is None: + return 0.0 + + # Guard against unexpected scorer failures / None for edge-case molecules. + try: + sa_raw = sascorer.calculateScore(m) + except Exception: + return 0.0 + if sa_raw is None: + return 0.0 + + sa = float(sa_raw) # ~ 1..10 + s_syn = 1.0 - (sa - 1.0) / 9.0 # linear map to [0,1] + return float(np.clip(s_syn, 0.0, 1.0)) + + +def compute_trust_scores( + df: pd.DataFrame, + real_fps: List, + real_smiles: List[str], + trust_weights: Dict[str, float] | None = None, +) -> np.ndarray: + """ + Trust score in [0,1] (higher = more trustworthy / lower risk). + Components: + - distance to nearest real polymer (fingerprint distance) + - internal consistency penalty + - uncertainty penalty (if std columns exist) + - synthesizability + """ + N = len(df) + trust = np.zeros(N, dtype=float) + tw_defaults = {"real": 0.45, "consistency": 0.25, "uncertainty": 0.10, "synth": 0.20} + tw = normalize_weights(trust_weights or {}, tw_defaults) + + # nearest-real distance (expensive if done naively) + # We do it only for the (small) post-filter set, which is safe. 
+ smiles_col = "smiles_key" if "smiles_key" in df.columns else "smiles_canon" + for i in range(N): + s = df.iloc[i][smiles_col] + fp = morgan_fp(s) + if fp is None or not real_fps: + d_real = 1.0 + else: + sims = DataStructs.BulkTanimotoSimilarity(fp, real_fps) + d_real = 1.0 - float(max(sims)) # distance to nearest + + # internal consistency + pen_cons = internal_consistency_penalty(df.iloc[i]) + + # uncertainty: average normalized std for any std_* columns present + std_cols = [c for c in df.columns if c.startswith("std_")] + if std_cols: + std_vals = df.iloc[i][std_cols].astype(float) + std_vals = std_vals.replace([np.inf, -np.inf], np.nan).dropna() + pen_unc = float(np.clip(std_vals.mean() / (std_vals.mean() + 1.0), 0.0, 1.0)) if len(std_vals) else 0.0 + else: + pen_unc = 0.0 + + # synthesizability heuristic + s_syn = synthesizability_score(s) + + # Combine (tunable weights) + # lower distance to real is better -> convert to score + s_real = 1.0 - np.clip(d_real, 0.0, 1.0) + + trust[i] = ( + tw["real"] * s_real + + tw["consistency"] * (1.0 - pen_cons) + + tw["uncertainty"] * (1.0 - pen_unc) + + tw["synth"] * s_syn + ) + + trust = np.clip(trust, 0.0, 1.0) + return trust + + +# ------------------------- +# Main pipeline +# ------------------------- +def run_discovery( + spec: DiscoverySpec, + progress_callback: Optional[Callable[[str, float], None]] = None, +) -> Tuple[pd.DataFrame, Dict[str, float], pd.DataFrame]: + def report(step: str, pct: float) -> None: + if progress_callback is not None: + progress_callback(step, pct) + + rng = np.random.default_rng(spec.random_seed) + + # 1) Determine required columns + report("Preparing columns…", 0.02) + obj_props = [o["property"].lower() for o in spec.objectives] + cons_props = [p.lower() for p in spec.hard_constraints.keys()] + + needed_props = sorted(set(obj_props + cons_props)) + cols = ["SMILES"] + [mean_col(p) for p in needed_props] + + # include std columns if available (not required, but used for trust) + std_cols = [std_col(p) for p in needed_props] + cols += std_cols + + # 2) Load only needed columns + report("Loading data from parquet…", 0.05) + df = load_parquet_columns(spec.dataset, columns=[c for c in cols if c != "SMILES"] + ["SMILES"]) + # normalize + if "SMILES" not in df.columns and "smiles" in df.columns: + df = df.rename(columns={"smiles": "SMILES"}) + normalize_step = "Canonicalizing SMILES…" if spec.use_canonical_smiles else "Skipping SMILES normalization…" + report(normalize_step, 0.10) + df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles)) + df = df.dropna(subset=["smiles_key"]).reset_index(drop=True) + + # 3) Hard constraints + report("Applying constraints…", 0.22) + for p, rule in spec.hard_constraints.items(): + p = p.lower() + c = mean_col(p) + if c not in df.columns: + # if missing, nothing can satisfy + df = df.iloc[0:0] + break + if "min" in rule: + df = df[df[c] >= float(rule["min"])] + if "max" in rule: + df = df[df[c] <= float(rule["max"])] + + n_after = len(df) + if n_after == 0: + empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0} + return df, empty_stats, pd.DataFrame() + + n_pool = len(df) + + # 5) Prepare objective matrix for Pareto + report("Building objective matrix…", 0.30) + # convert to minimization: maximize => negate + X = [] + resolved_objectives = [] + for o in spec.objectives: + prop = o["property"].lower() + goal = o["goal"].lower() + c = mean_col(prop) + if c not in df.columns: + continue 
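+ # column resolved: 'maximize' objectives are negated below so every objective is
+ # treated as minimization for the Pareto step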
+ v = df[c].to_numpy(dtype=float) + if goal == "maximize": + v = -v + X.append(v) + resolved_objectives.append({"property": prop, "goal": goal}) + if not X: + # Fallback to first available mean_* column to keep pipeline runnable. + fallback_col = next((c for c in df.columns if str(c).startswith("mean_")), None) + if fallback_col is None: + empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0} + return df.iloc[0:0], empty_stats, pd.DataFrame() + X = [df[fallback_col].to_numpy(dtype=float) * -1.0] + resolved_objectives = [{"property": fallback_col.replace("mean_", ""), "goal": "maximize"}] + X = np.stack(X, axis=1) # (N, M) + obj_props = [o["property"] for o in resolved_objectives] + + # Pareto cap before computing layers (optional safety) + if spec.use_full_data: + report("Using full dataset (no Pareto cap)…", 0.35) + elif len(df) > spec.pareto_max: + idx = rng.choice(len(df), size=spec.pareto_max, replace=False) + df = df.iloc[idx].reset_index(drop=True) + X = X[idx] + + # 6) Pareto layers (only 5 layers needed for candidate pool) + report("Computing Pareto layers…", 0.40) + pareto_start = 0.40 + pareto_end = 0.54 + max_layers_for_pool = max(1, int(spec.max_pareto_fronts)) + pareto_chunk_ref = {"chunks_per_layer": None} + + def on_pareto_chunk(layer_i: int, done_chunks: int, total_chunks: int) -> None: + if pareto_chunk_ref["chunks_per_layer"] is None: + pareto_chunk_ref["chunks_per_layer"] = max(1, int(total_chunks)) + ref_chunks = pareto_chunk_ref["chunks_per_layer"] + total_units = max_layers_for_pool * ref_chunks + done_units = min(total_units, ((layer_i - 1) * ref_chunks) + done_chunks) + pareto_pct = int(round(100.0 * done_units / max(1, total_units))) + + layer_progress = done_chunks / max(1, total_chunks) + overall = ((layer_i - 1) + layer_progress) / max_layers_for_pool + pct = pareto_start + (pareto_end - pareto_start) * min(1.0, max(0.0, overall)) + report( + f"Computing Pareto layers… {pareto_pct}% (Layer {layer_i}/{max_layers_for_pool}, chunk {done_chunks}/{total_chunks})", + pct, + ) + + layers = pareto_layers_chunked( + X, + max_layers=max_layers_for_pool, + chunk_size=100000, + progress_callback=on_pareto_chunk, + ) + report("Computing Pareto layers…", pareto_end) + df["pareto_layer"] = layers + plot_df = df[["smiles_key"] + [mean_col(p) for p in obj_props] + ["pareto_layer"]].copy() + plot_df = plot_df.rename(columns={"smiles_key": "SMILES"}) + + # Keep first few layers as candidate pool (avoid huge set) + cand = df[df["pareto_layer"].between(1, max_layers_for_pool)].copy() + if cand.empty: + cand = df[df["pareto_layer"] == 1].copy() + cand = cand.reset_index(drop=True) + n_pareto = len(cand) + + # 7) Load real polymer metadata and fingerprints (from POLYINFO.csv) + report("Loading POLYINFO index…", 0.55) + polyinfo = load_polyinfo_index(spec.polyinfo_csv, use_canonical_smiles=spec.use_canonical_smiles) + real_smiles = polyinfo.index.to_list() + + report("Building real-polymer fingerprints…", 0.60) + real_fps = [] + for s in real_smiles: + fp = morgan_fp(s) + if fp is not None: + real_fps.append(fp) + + # 8) Trust score on candidate pool (safe size) + report("Computing trust scores…", 0.70) + trust = compute_trust_scores( + cand, + real_fps=real_fps, + real_smiles=real_smiles, + trust_weights=spec.trust_weights, + ) + cand["trust_score"] = trust + + # 9) Diversity selection on candidate pool + report("Diversity selection…", 0.88) + # score for selection: prioritize Pareto layer 1 then trust + # higher is better + 
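+ # pareto_bonus maps layer 1 -> 1.0 and layer max_layers_for_pool -> 1/max_layers_for_pool,
+ # so front-1 candidates receive the full Pareto weight before trust is blended in
+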
sw_defaults = {"pareto": 0.60, "trust": 0.40} + sw = normalize_weights(spec.selection_weights or {}, sw_defaults) + pareto_bonus = ( + (max_layers_for_pool + 1) - np.clip(cand["pareto_layer"].to_numpy(dtype=int), 1, max_layers_for_pool) + ) / float(max_layers_for_pool) + sel_score = sw["pareto"] * pareto_bonus + sw["trust"] * cand["trust_score"].to_numpy(dtype=float) + + chosen_idx = greedy_diverse_select( + smiles_list=cand["smiles_key"].tolist(), + scores=sel_score, + max_k=spec.max_candidates, + min_dist=spec.min_distance, + ) + out = cand.iloc[chosen_idx].copy().reset_index(drop=True) + + # 10) Attach Polymer_Name/Class if available (only for matches) + report("Finalizing results…", 0.96) + out = out.set_index("smiles_key", drop=False) + out = out.join(polyinfo, how="left") + out = out.reset_index(drop=True) + + # 11) Make a clean output bundle with requested columns + # Keep SMILES (canonical), name/class, pareto layer, trust score, properties used + keep = ["smiles_key", "polymer_name", "polymer_class", "pareto_layer", "trust_score"] + for p in needed_props: + mc = mean_col(p) + sc = std_col(p) + if mc in out.columns: + keep.append(mc) + if sc in out.columns: + keep.append(sc) + + out = out[keep].rename(columns={"smiles_key": "SMILES"}) + + stats = { + "n_total": float(len(df)), + "n_after_constraints": float(n_after), + "n_pool": float(n_pool), + "n_pareto_pool": float(n_pareto), + "n_selected": float(len(out)), + } + report("Done.", 1.0) + return out, stats, plot_df + + +def build_pareto_plot_df(spec: DiscoverySpec, max_plot_points: int = 30000) -> pd.DataFrame: + """ + Returns a small dataframe for plotting (sampled), with objective columns and pareto_layer. + Does NOT compute trust/diversity. Safe for live plotting. + """ + rng = np.random.default_rng(spec.random_seed) + + obj_props = [o["property"].lower() for o in spec.objectives] + cons_props = [p.lower() for p in spec.hard_constraints.keys()] + needed_props = sorted(set(obj_props + cons_props)) + + cols = ["SMILES"] + [mean_col(p) for p in needed_props] + df = load_parquet_columns(spec.dataset, columns=cols) + + if "SMILES" not in df.columns and "smiles" in df.columns: + df = df.rename(columns={"smiles": "SMILES"}) + + df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles)) + df = df.dropna(subset=["smiles_key"]).reset_index(drop=True) + + # Hard constraints + for p, rule in spec.hard_constraints.items(): + p = p.lower() + c = mean_col(p) + if c not in df.columns: + return df.iloc[0:0] + if "min" in rule: + df = df[df[c] >= float(rule["min"])] + if "max" in rule: + df = df[df[c] <= float(rule["max"])] + + if len(df) == 0: + return df + + # Pareto cap for plotting + plot_cap = min(int(max_plot_points), int(spec.pareto_max)) + if len(df) > plot_cap: + idx = rng.choice(len(df), size=plot_cap, replace=False) + df = df.iloc[idx].reset_index(drop=True) + + # Build objective matrix (minimization) + X = [] + resolved_obj_props = [] + for o in spec.objectives: + prop = o["property"].lower() + goal = o["goal"].lower() + c = mean_col(prop) + if c not in df.columns: + continue + v = df[c].to_numpy(dtype=float) + if goal == "maximize": + v = -v + X.append(v) + resolved_obj_props.append(prop) + if not X: + fallback_col = next((c for c in df.columns if str(c).startswith("mean_")), None) + if fallback_col is None: + return df.iloc[0:0] + X = [df[fallback_col].to_numpy(dtype=float) * -1.0] + resolved_obj_props = [fallback_col.replace("mean_", "")] + X = np.stack(X, axis=1) + + 
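+ # Layers below are computed on the (possibly subsampled) plotting frame and capped at 5;
+ # they are for display only (see the docstring above), not for candidate selection.
+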
df["pareto_layer"] = pareto_layers(X, max_layers=5) + + # Return only what plotting needs + keep = ["smiles_key", "pareto_layer"] + [mean_col(p) for p in resolved_obj_props] + out = df[keep].rename(columns={"smiles_key": "SMILES"}) + return out + + +def parse_spec(text: str, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec: + obj = json.loads(text) + pareto_max = int(obj.get("pareto_max", 50000)) + + return DiscoverySpec( + dataset=list(dataset_path), + polyinfo=polyinfo_path, + polyinfo_csv=polyinfo_csv_path, + hard_constraints=obj.get("hard_constraints", {}), + objectives=obj.get("objectives", []), + max_pool=pareto_max, + pareto_max=pareto_max, + max_candidates=int(obj.get("max_candidates", 30)), + max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)), + min_distance=float(obj.get("min_distance", 0.30)), + fingerprint=str(obj.get("fingerprint", "morgan")), + random_seed=int(obj.get("random_seed", 7)), + use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)), + use_full_data=bool(obj.get("use_full_data", False)), + trust_weights=obj.get("trust_weights"), + selection_weights=obj.get("selection_weights"), + ) diff --git a/src/discovery.py b/src/discovery.py new file mode 100644 index 0000000000000000000000000000000000000000..614a4544c5996418ac69bd3c9002cbf0c2eca1b0 --- /dev/null +++ b/src/discovery.py @@ -0,0 +1,767 @@ +# src/discovery.py +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +from rdkit import Chem, DataStructs +from rdkit.Chem import AllChem +from . import sascorer + +# Reuse your canonicalizer if you want; otherwise keep local +def canonicalize_smiles(smiles: str) -> Optional[str]: + s = (smiles or "").strip() + if not s: + return None + m = Chem.MolFromSmiles(s) + if m is None: + return None + return Chem.MolToSmiles(m, canonical=True) + + +# ------------------------- +# Spec schema (minimal v0) +# ------------------------- +@dataclass +class DiscoverySpec: + dataset: List[str] # ["PI1M_PROPERTY.parquet", "POLYINFO_PROPERTY.parquet"] + polyinfo: str # "POLYINFO_PROPERTY.parquet" + polyinfo_csv: str # "POLYINFO.csv" + + hard_constraints: Dict[str, Dict[str, float]] # { "tg": {"min": 400}, "tc": {"max": 0.3} } + objectives: List[Dict[str, str]] # [{"property":"cp","goal":"maximize"}, ...] 
+ + max_pool: int = 200000 # legacy (kept for compatibility; aligned to pareto_max) + pareto_max: int = 50000 # cap points used for Pareto + diversity fingerprinting + max_candidates: int = 30 # final output size + max_pareto_fronts: int = 5 # how many Pareto layers to keep for candidate pool + min_distance: float = 0.30 # diversity threshold in Tanimoto distance + fingerprint: str = "morgan" # morgan only for now + random_seed: int = 7 + use_canonical_smiles: bool = True + use_full_data: bool = False + trust_weights: Dict[str, float] | None = None + selection_weights: Dict[str, float] | None = None + + +# ------------------------- +# Column mapping +# ------------------------- +def mean_col(prop_key: str) -> str: + return f"mean_{prop_key.lower()}" + +def std_col(prop_key: str) -> str: + return f"std_{prop_key.lower()}" + + +def normalize_weights(weights: Dict[str, float], defaults: Dict[str, float]) -> Dict[str, float]: + out: Dict[str, float] = {} + for k, v in defaults.items(): + try: + vv = float(weights.get(k, v)) + except Exception: + vv = float(v) + out[k] = max(0.0, vv) + s = float(sum(out.values())) + if s <= 0.0: + return defaults.copy() + return {k: float(v / s) for k, v in out.items()} + +def spec_from_dict(obj: dict, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec: + pareto_max = int(obj.get("pareto_max", 50000)) + return DiscoverySpec( + dataset=list(dataset_path), + polyinfo=polyinfo_path, + polyinfo_csv=polyinfo_csv_path, + hard_constraints=obj.get("hard_constraints", {}), + objectives=obj.get("objectives", []), + # Legacy field kept for compatibility; effectively collapsed to pareto_max. + max_pool=pareto_max, + pareto_max=pareto_max, + max_candidates=int(obj.get("max_candidates", 30)), + max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)), + min_distance=float(obj.get("min_distance", 0.30)), + fingerprint=str(obj.get("fingerprint", "morgan")), + random_seed=int(obj.get("random_seed", 7)), + use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)), + use_full_data=bool(obj.get("use_full_data", False)), + trust_weights=obj.get("trust_weights"), + selection_weights=obj.get("selection_weights"), + ) + +# ------------------------- +# Parquet loading (safe) +# ------------------------- +def load_parquet_columns(path: str | List[str], columns: List[str]) -> pd.DataFrame: + """ + Load only requested columns from Parquet (critical for 1M rows). + Accepts a single path or a list of paths and concatenates rows. + """ + def _load_one(fp: str, req_cols: List[str]) -> pd.DataFrame: + available: list[str] + try: + import pyarrow.parquet as pq + + pf = pq.ParquetFile(fp) + available = [str(c) for c in pf.schema.names] + except Exception: + # If schema probing fails, fall back to direct read with requested columns. + return pd.read_parquet(fp, columns=req_cols) + + available_set = set(available) + lower_to_actual = {c.lower(): c for c in available} + + # Resolve requested names against actual parquet schema. 
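+ # Matching is case-insensitive; requested columns that still cannot be resolved are
+ # filled with NaN after the read, so callers always receive every column they asked for.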
+ resolved: dict[str, str] = {} + for req in req_cols: + if req in available_set: + resolved[req] = req + continue + alt = lower_to_actual.get(str(req).lower()) + if alt is not None: + resolved[req] = alt + + use_cols = sorted(set(resolved.values())) + if not use_cols: + return pd.DataFrame(columns=req_cols) + + out = pd.read_parquet(fp, columns=use_cols) + for req in req_cols: + src = resolved.get(req) + if src is None: + out[req] = np.nan + elif src != req: + out[req] = out[src] + return out[req_cols] + + if isinstance(path, (list, tuple)): + frames = [_load_one(p, columns) for p in path] + if not frames: + return pd.DataFrame(columns=columns) + return pd.concat(frames, ignore_index=True) + return _load_one(path, columns) + + +def normalize_smiles(smiles: str, use_canonical_smiles: bool) -> Optional[str]: + s = (smiles or "").strip() + if not s: + return None + if not use_canonical_smiles: + # Skip RDKit parsing entirely in fast mode. + return s + m = Chem.MolFromSmiles(s) + if m is None: + return None + if use_canonical_smiles: + return Chem.MolToSmiles(m, canonical=True) + return s + + +def load_polyinfo_index(polyinfo_csv_path: str, use_canonical_smiles: bool = True) -> pd.DataFrame: + """ + Expected CSV columns: SMILES, Polymer_Class, polymer_name (or common variants). + Returns dataframe with index on smiles_key and columns polymer_name/polymer_class. + """ + df = pd.read_csv(polyinfo_csv_path) + + # normalize column names + cols = {c: c for c in df.columns} + # map typical names + if "SMILES" in cols: + df = df.rename(columns={"SMILES": "smiles"}) + elif "smiles" not in df.columns: + raise ValueError(f"{polyinfo_csv_path} missing SMILES/smiles column") + + if "Polymer_Name" in df.columns: + df = df.rename(columns={"Polymer_Name": "polymer_name"}) + if "polymer_Name" in df.columns: + df = df.rename(columns={"polymer_Name": "polymer_name"}) + if "Polymer_Class" in df.columns: + df = df.rename(columns={"Polymer_Class": "polymer_class"}) + + if "polymer_name" not in df.columns: + df["polymer_name"] = pd.NA + if "polymer_class" not in df.columns: + df["polymer_class"] = pd.NA + + df["smiles_key"] = df["smiles"].astype(str).map(lambda s: normalize_smiles(s, use_canonical_smiles)) + df = df.dropna(subset=["smiles_key"]).drop_duplicates("smiles_key") + df = df.set_index("smiles_key", drop=True) + return df[["polymer_name", "polymer_class"]] + + +# ------------------------- +# Pareto (2–3 objectives) +# ------------------------- +def pareto_front_mask(X: np.ndarray) -> np.ndarray: + """ + Returns mask for nondominated points. + X: (N, M), all objectives assumed to be minimized. + For maximize objectives, we invert before calling this. + """ + N = X.shape[0] + is_efficient = np.ones(N, dtype=bool) + for i in range(N): + if not is_efficient[i]: + continue + # any point that is <= in all dims and < in at least one dominates + dominates = np.all(X <= X[i], axis=1) & np.any(X < X[i], axis=1) + # if a point dominates i, mark i inefficient + if np.any(dominates): + is_efficient[i] = False + continue + # otherwise, i may dominate others + dominated_by_i = np.all(X[i] <= X, axis=1) & np.any(X[i] < X, axis=1) + is_efficient[dominated_by_i] = False + is_efficient[i] = True + return is_efficient + + +def pareto_layers(X: np.ndarray, max_layers: int = 10) -> np.ndarray: + """ + Returns layer index per point: 1 = Pareto front, 2 = second layer, ... + Unassigned points beyond max_layers get 0. 
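+ Example (single minimized objective): X = [[3], [1], [2]] yields layers [3, 1, 2].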
+ """ + N = X.shape[0] + layers = np.zeros(N, dtype=int) + remaining = np.arange(N) + + layer = 1 + while remaining.size > 0 and layer <= max_layers: + mask = pareto_front_mask(X[remaining]) + front_idx = remaining[mask] + layers[front_idx] = layer + remaining = remaining[~mask] + layer += 1 + return layers + + +def pareto_front_mask_chunked( + X: np.ndarray, + chunk_size: int = 100000, + progress_callback: Optional[Callable[[int, int], None]] = None, +) -> np.ndarray: + """ + Exact global Pareto front mask via chunk-local front reduction + global reconcile. + This is exact for front-1: + 1) compute exact local front within each chunk + 2) union local fronts + 3) compute exact front on the union + """ + N = X.shape[0] + if N <= chunk_size: + if progress_callback is not None: + progress_callback(1, 1) + return pareto_front_mask(X) + + local_front_idx = [] + total_chunks = (N + chunk_size - 1) // chunk_size + done_chunks = 0 + for start in range(0, N, chunk_size): + end = min(start + chunk_size, N) + idx = np.arange(start, end) + mask_local = pareto_front_mask(X[idx]) + local_front_idx.append(idx[mask_local]) + done_chunks += 1 + if progress_callback is not None: + progress_callback(done_chunks, total_chunks) + + if not local_front_idx: + return np.zeros(N, dtype=bool) + + reduced_idx = np.concatenate(local_front_idx) + reduced_mask = pareto_front_mask(X[reduced_idx]) + front_idx = reduced_idx[reduced_mask] + + out = np.zeros(N, dtype=bool) + out[front_idx] = True + return out + + +def pareto_layers_chunked( + X: np.ndarray, + max_layers: int = 10, + chunk_size: int = 100000, + progress_callback: Optional[Callable[[int, int, int], None]] = None, +) -> np.ndarray: + """ + Exact Pareto layers using repeated exact chunked front extraction. + """ + N = X.shape[0] + layers = np.zeros(N, dtype=int) + remaining = np.arange(N) + layer = 1 + + while remaining.size > 0 and layer <= max_layers: + def on_chunk(done: int, total: int) -> None: + if progress_callback is not None: + progress_callback(layer, done, total) + + mask = pareto_front_mask_chunked(X[remaining], chunk_size=chunk_size, progress_callback=on_chunk) + front_idx = remaining[mask] + layers[front_idx] = layer + remaining = remaining[~mask] + layer += 1 + + return layers + + +# ------------------------- +# Fingerprints & diversity +# ------------------------- +def morgan_fp(smiles: str, radius: int = 2, nbits: int = 2048): + m = Chem.MolFromSmiles(smiles) + if m is None: + return None + return AllChem.GetMorganFingerprintAsBitVect(m, radius, nBits=nbits) + +def tanimoto_distance(fp1, fp2) -> float: + return 1.0 - DataStructs.TanimotoSimilarity(fp1, fp2) + +def greedy_diverse_select( + smiles_list: List[str], + scores: np.ndarray, + max_k: int, + min_dist: float, +) -> List[int]: + """ + Greedy selection by descending score, enforcing min Tanimoto distance. + Returns indices into smiles_list. 
+ """ + fps = [] + valid_idx = [] + for i, s in enumerate(smiles_list): + fp = morgan_fp(s) + if fp is not None: + fps.append(fp) + valid_idx.append(i) + + if not valid_idx: + return [] + + # rank candidates (higher score first) + order = np.argsort(-scores[valid_idx]) + selected_global = [] + selected_fps = [] + + for oi in order: + i = valid_idx[oi] + fp_i = fps[oi] # aligned with valid_idx + ok = True + for fp_j in selected_fps: + if tanimoto_distance(fp_i, fp_j) < min_dist: + ok = False + break + if ok: + selected_global.append(i) + selected_fps.append(fp_i) + if len(selected_global) >= max_k: + break + + return selected_global + + +# ------------------------- +# Trust score (lightweight, robust) +# ------------------------- +def internal_consistency_penalty(row: pd.Series) -> float: + """ + Very simple physics/validity checks. Penalty in [0,1]. + Adjust/add rules later. + """ + viol = 0 + total = 0 + + def chk(cond: bool): + nonlocal viol, total + total += 1 + if not cond: + viol += 1 + + # positivity checks if present + for p in ["cp", "tc", "rho", "dif", "visc", "tg", "tm", "bandgap"]: + c = mean_col(p) + if c in row.index and pd.notna(row[c]): + if p in ["bandgap", "tg", "tm"]: + chk(float(row[c]) >= 0.0) + else: + chk(float(row[c]) > 0.0) + + # Poisson ratio bounds if present + if mean_col("poisson") in row.index and pd.notna(row[mean_col("poisson")]): + v = float(row[mean_col("poisson")]) + chk(0.0 <= v <= 0.5) + + # Tg <= Tm if both present + if mean_col("tg") in row.index and mean_col("tm") in row.index: + if pd.notna(row[mean_col("tg")]) and pd.notna(row[mean_col("tm")]): + chk(float(row[mean_col("tg")]) <= float(row[mean_col("tm")])) + + if total == 0: + return 0.0 + return viol / total + + +def synthesizability_score(smiles: str) -> float: + """ + RDKit SA-score based synthesizability proxy in [0,1]. + SA-score is ~[1 (easy), 10 (hard)]. + We map: 1 -> 1.0, 10 -> 0.0 + """ + m = Chem.MolFromSmiles(smiles) + if m is None: + return 0.0 + + # Guard against unexpected scorer failures / None for edge-case molecules. + try: + sa_raw = sascorer.calculateScore(m) + except Exception: + return 0.0 + if sa_raw is None: + return 0.0 + + sa = float(sa_raw) # ~ 1..10 + s_syn = 1.0 - (sa - 1.0) / 9.0 # linear map to [0,1] + return float(np.clip(s_syn, 0.0, 1.0)) + + +def compute_trust_scores( + df: pd.DataFrame, + real_fps: List, + real_smiles: List[str], + trust_weights: Dict[str, float] | None = None, +) -> np.ndarray: + """ + Trust score in [0,1] (higher = more trustworthy / lower risk). + Components: + - distance to nearest real polymer (fingerprint distance) + - internal consistency penalty + - uncertainty penalty (if std columns exist) + - synthesizability + """ + N = len(df) + trust = np.zeros(N, dtype=float) + tw_defaults = {"real": 0.45, "consistency": 0.25, "uncertainty": 0.10, "synth": 0.20} + tw = normalize_weights(trust_weights or {}, tw_defaults) + + # nearest-real distance (expensive if done naively) + # We do it only for the (small) post-filter set, which is safe. 
+ smiles_col = "smiles_key" if "smiles_key" in df.columns else "smiles_canon" + for i in range(N): + s = df.iloc[i][smiles_col] + fp = morgan_fp(s) + if fp is None or not real_fps: + d_real = 1.0 + else: + sims = DataStructs.BulkTanimotoSimilarity(fp, real_fps) + d_real = 1.0 - float(max(sims)) # distance to nearest + + # internal consistency + pen_cons = internal_consistency_penalty(df.iloc[i]) + + # uncertainty: average normalized std for any std_* columns present + std_cols = [c for c in df.columns if c.startswith("std_")] + if std_cols: + std_vals = df.iloc[i][std_cols].astype(float) + std_vals = std_vals.replace([np.inf, -np.inf], np.nan).dropna() + pen_unc = float(np.clip(std_vals.mean() / (std_vals.mean() + 1.0), 0.0, 1.0)) if len(std_vals) else 0.0 + else: + pen_unc = 0.0 + + # synthesizability heuristic + s_syn = synthesizability_score(s) + + # Combine (tunable weights) + # lower distance to real is better -> convert to score + s_real = 1.0 - np.clip(d_real, 0.0, 1.0) + + trust[i] = ( + tw["real"] * s_real + + tw["consistency"] * (1.0 - pen_cons) + + tw["uncertainty"] * (1.0 - pen_unc) + + tw["synth"] * s_syn + ) + + trust = np.clip(trust, 0.0, 1.0) + return trust + + +# ------------------------- +# Main pipeline +# ------------------------- +def run_discovery( + spec: DiscoverySpec, + progress_callback: Optional[Callable[[str, float], None]] = None, +) -> Tuple[pd.DataFrame, Dict[str, float], pd.DataFrame]: + def report(step: str, pct: float) -> None: + if progress_callback is not None: + progress_callback(step, pct) + + rng = np.random.default_rng(spec.random_seed) + + # 1) Determine required columns + report("Preparing columns…", 0.02) + obj_props = [o["property"].lower() for o in spec.objectives] + cons_props = [p.lower() for p in spec.hard_constraints.keys()] + + needed_props = sorted(set(obj_props + cons_props)) + cols = ["SMILES"] + [mean_col(p) for p in needed_props] + + # include std columns if available (not required, but used for trust) + std_cols = [std_col(p) for p in needed_props] + cols += std_cols + + # 2) Load only needed columns + report("Loading data from parquet…", 0.05) + df = load_parquet_columns(spec.dataset, columns=[c for c in cols if c != "SMILES"] + ["SMILES"]) + # normalize + if "SMILES" not in df.columns and "smiles" in df.columns: + df = df.rename(columns={"smiles": "SMILES"}) + normalize_step = "Canonicalizing SMILES…" if spec.use_canonical_smiles else "Skipping SMILES normalization…" + report(normalize_step, 0.10) + df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles)) + df = df.dropna(subset=["smiles_key"]).reset_index(drop=True) + + # 3) Hard constraints + report("Applying constraints…", 0.22) + for p, rule in spec.hard_constraints.items(): + p = p.lower() + c = mean_col(p) + if c not in df.columns: + # if missing, nothing can satisfy + df = df.iloc[0:0] + break + if "min" in rule: + df = df[df[c] >= float(rule["min"])] + if "max" in rule: + df = df[df[c] <= float(rule["max"])] + + n_after = len(df) + if n_after == 0: + empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0} + return df, empty_stats, pd.DataFrame() + + n_pool = len(df) + + # 5) Prepare objective matrix for Pareto + report("Building objective matrix…", 0.30) + # convert to minimization: maximize => negate + X = [] + for o in spec.objectives: + prop = o["property"].lower() + goal = o["goal"].lower() + c = mean_col(prop) + if c not in df.columns: + raise ValueError(f"Objective column 
missing: {c}") + v = df[c].to_numpy(dtype=float) + if goal == "maximize": + v = -v + X.append(v) + X = np.stack(X, axis=1) # (N, M) + + # Pareto cap before computing layers (optional safety) + if spec.use_full_data: + report("Using full dataset (no Pareto cap)…", 0.35) + elif len(df) > spec.pareto_max: + idx = rng.choice(len(df), size=spec.pareto_max, replace=False) + df = df.iloc[idx].reset_index(drop=True) + X = X[idx] + + # 6) Pareto layers (only 5 layers needed for candidate pool) + report("Computing Pareto layers…", 0.40) + pareto_start = 0.40 + pareto_end = 0.54 + max_layers_for_pool = max(1, int(spec.max_pareto_fronts)) + pareto_chunk_ref = {"chunks_per_layer": None} + + def on_pareto_chunk(layer_i: int, done_chunks: int, total_chunks: int) -> None: + if pareto_chunk_ref["chunks_per_layer"] is None: + pareto_chunk_ref["chunks_per_layer"] = max(1, int(total_chunks)) + ref_chunks = pareto_chunk_ref["chunks_per_layer"] + total_units = max_layers_for_pool * ref_chunks + done_units = min(total_units, ((layer_i - 1) * ref_chunks) + done_chunks) + pareto_pct = int(round(100.0 * done_units / max(1, total_units))) + + layer_progress = done_chunks / max(1, total_chunks) + overall = ((layer_i - 1) + layer_progress) / max_layers_for_pool + pct = pareto_start + (pareto_end - pareto_start) * min(1.0, max(0.0, overall)) + report( + f"Computing Pareto layers… {pareto_pct}% (Layer {layer_i}/{max_layers_for_pool}, chunk {done_chunks}/{total_chunks})", + pct, + ) + + layers = pareto_layers_chunked( + X, + max_layers=max_layers_for_pool, + chunk_size=100000, + progress_callback=on_pareto_chunk, + ) + report("Computing Pareto layers…", pareto_end) + df["pareto_layer"] = layers + plot_df = df[["smiles_key"] + [mean_col(p) for p in obj_props] + ["pareto_layer"]].copy() + plot_df = plot_df.rename(columns={"smiles_key": "SMILES"}) + + # Keep first few layers as candidate pool (avoid huge set) + cand = df[df["pareto_layer"].between(1, max_layers_for_pool)].copy() + if cand.empty: + cand = df[df["pareto_layer"] == 1].copy() + cand = cand.reset_index(drop=True) + n_pareto = len(cand) + + # 7) Load real polymer metadata and fingerprints (from POLYINFO.csv) + report("Loading POLYINFO index…", 0.55) + polyinfo = load_polyinfo_index(spec.polyinfo_csv, use_canonical_smiles=spec.use_canonical_smiles) + real_smiles = polyinfo.index.to_list() + + report("Building real-polymer fingerprints…", 0.60) + real_fps = [] + for s in real_smiles: + fp = morgan_fp(s) + if fp is not None: + real_fps.append(fp) + + # 8) Trust score on candidate pool (safe size) + report("Computing trust scores…", 0.70) + trust = compute_trust_scores( + cand, + real_fps=real_fps, + real_smiles=real_smiles, + trust_weights=spec.trust_weights, + ) + cand["trust_score"] = trust + + # 9) Diversity selection on candidate pool + report("Diversity selection…", 0.88) + # score for selection: prioritize Pareto layer 1 then trust + # higher is better + sw_defaults = {"pareto": 0.60, "trust": 0.40} + sw = normalize_weights(spec.selection_weights or {}, sw_defaults) + pareto_bonus = ( + (max_layers_for_pool + 1) - np.clip(cand["pareto_layer"].to_numpy(dtype=int), 1, max_layers_for_pool) + ) / float(max_layers_for_pool) + sel_score = sw["pareto"] * pareto_bonus + sw["trust"] * cand["trust_score"].to_numpy(dtype=float) + + chosen_idx = greedy_diverse_select( + smiles_list=cand["smiles_key"].tolist(), + scores=sel_score, + max_k=spec.max_candidates, + min_dist=spec.min_distance, + ) + out = cand.iloc[chosen_idx].copy().reset_index(drop=True) + + # 10) Attach 
Polymer_Name/Class if available (only for matches) + report("Finalizing results…", 0.96) + out = out.set_index("smiles_key", drop=False) + out = out.join(polyinfo, how="left") + out = out.reset_index(drop=True) + + # 11) Make a clean output bundle with requested columns + # Keep SMILES (canonical), name/class, pareto layer, trust score, properties used + keep = ["smiles_key", "polymer_name", "polymer_class", "pareto_layer", "trust_score"] + for p in needed_props: + mc = mean_col(p) + sc = std_col(p) + if mc in out.columns: + keep.append(mc) + if sc in out.columns: + keep.append(sc) + + out = out[keep].rename(columns={"smiles_key": "SMILES"}) + + stats = { + "n_total": float(len(df)), + "n_after_constraints": float(n_after), + "n_pool": float(n_pool), + "n_pareto_pool": float(n_pareto), + "n_selected": float(len(out)), + } + report("Done.", 1.0) + return out, stats, plot_df + + +def build_pareto_plot_df(spec: DiscoverySpec, max_plot_points: int = 30000) -> pd.DataFrame: + """ + Returns a small dataframe for plotting (sampled), with objective columns and pareto_layer. + Does NOT compute trust/diversity. Safe for live plotting. + """ + rng = np.random.default_rng(spec.random_seed) + + obj_props = [o["property"].lower() for o in spec.objectives] + cons_props = [p.lower() for p in spec.hard_constraints.keys()] + needed_props = sorted(set(obj_props + cons_props)) + + cols = ["SMILES"] + [mean_col(p) for p in needed_props] + df = load_parquet_columns(spec.dataset, columns=cols) + + if "SMILES" not in df.columns and "smiles" in df.columns: + df = df.rename(columns={"smiles": "SMILES"}) + + df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles)) + df = df.dropna(subset=["smiles_key"]).reset_index(drop=True) + + # Hard constraints + for p, rule in spec.hard_constraints.items(): + p = p.lower() + c = mean_col(p) + if c not in df.columns: + return df.iloc[0:0] + if "min" in rule: + df = df[df[c] >= float(rule["min"])] + if "max" in rule: + df = df[df[c] <= float(rule["max"])] + + if len(df) == 0: + return df + + # Pareto cap for plotting + plot_cap = min(int(max_plot_points), int(spec.pareto_max)) + if len(df) > plot_cap: + idx = rng.choice(len(df), size=plot_cap, replace=False) + df = df.iloc[idx].reset_index(drop=True) + + # Build objective matrix (minimization) + X = [] + for o in spec.objectives: + prop = o["property"].lower() + goal = o["goal"].lower() + c = mean_col(prop) + v = df[c].to_numpy(dtype=float) + if goal == "maximize": + v = -v + X.append(v) + X = np.stack(X, axis=1) + + df["pareto_layer"] = pareto_layers(X, max_layers=5) + + # Return only what plotting needs + keep = ["smiles_key", "pareto_layer"] + [mean_col(p) for p in obj_props] + out = df[keep].rename(columns={"smiles_key": "SMILES"}) + return out + + +def parse_spec(text: str, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec: + obj = json.loads(text) + pareto_max = int(obj.get("pareto_max", 50000)) + + return DiscoverySpec( + dataset=list(dataset_path), + polyinfo=polyinfo_path, + polyinfo_csv=polyinfo_csv_path, + hard_constraints=obj.get("hard_constraints", {}), + objectives=obj.get("objectives", []), + max_pool=pareto_max, + pareto_max=pareto_max, + max_candidates=int(obj.get("max_candidates", 30)), + max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)), + min_distance=float(obj.get("min_distance", 0.30)), + fingerprint=str(obj.get("fingerprint", "morgan")), + random_seed=int(obj.get("random_seed", 7)), + use_canonical_smiles=not 
bool(obj.get("skip_smiles_canonicalization", True)), + use_full_data=bool(obj.get("use_full_data", False)), + trust_weights=obj.get("trust_weights"), + selection_weights=obj.get("selection_weights"), + ) diff --git a/src/fpscores.pkl.gz b/src/fpscores.pkl.gz new file mode 100644 index 0000000000000000000000000000000000000000..aa6f88c9c3fa56161b7df08e74ea6824f3071d08 --- /dev/null +++ b/src/fpscores.pkl.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10dcef9340c873e7b987924461b0af5365eb8dd96be607203debe8ddf80c1e73 +size 3848394 diff --git a/src/literature_service/__init__.py b/src/literature_service/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..638c40d182341516ef3039c885a6e9f02207d25b --- /dev/null +++ b/src/literature_service/__init__.py @@ -0,0 +1,34 @@ +from .db import Database, get_database +from .repository import ( + ProjectRepo, + PaperRepo, + RunRepo, + DataPointRepo, + ExtractionJobRepo, + ManualUploadRepo, + QuerySessionRepo, + PageIndexRepo, + QAMessageRepo, +) +from .pipeline import LiteraturePipeline +from .manual_upload import ManualUploadService +from .query_intent import QueryIntentService +from .pageindex_client import PageIndexService + +__all__ = [ + "Database", + "get_database", + "ProjectRepo", + "PaperRepo", + "RunRepo", + "DataPointRepo", + "ExtractionJobRepo", + "ManualUploadRepo", + "QuerySessionRepo", + "PageIndexRepo", + "QAMessageRepo", + "LiteraturePipeline", + "ManualUploadService", + "QueryIntentService", + "PageIndexService", +] diff --git a/src/literature_service/__pycache__/__init__.cpython-310.pyc b/src/literature_service/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26a6a3c80773f577e038d217a860b0c00c8ba609 Binary files /dev/null and b/src/literature_service/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/literature_service/__pycache__/__init__.cpython-313.pyc b/src/literature_service/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db31a93413804c84698956d03da15eba014e1c0f Binary files /dev/null and b/src/literature_service/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/literature_service/__pycache__/db.cpython-310.pyc b/src/literature_service/__pycache__/db.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56bb2693a7e7b48a8d42a5cab3947a5f6489b096 Binary files /dev/null and b/src/literature_service/__pycache__/db.cpython-310.pyc differ diff --git a/src/literature_service/__pycache__/db.cpython-313.pyc b/src/literature_service/__pycache__/db.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b426b460081459e5079a03dd4ccf1b5f2196fb7c Binary files /dev/null and b/src/literature_service/__pycache__/db.cpython-313.pyc differ diff --git a/src/literature_service/__pycache__/manual_upload.cpython-310.pyc b/src/literature_service/__pycache__/manual_upload.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2dbaa93f7084a91f29dcbf717c951c82d1b1033 Binary files /dev/null and b/src/literature_service/__pycache__/manual_upload.cpython-310.pyc differ diff --git a/src/literature_service/__pycache__/manual_upload.cpython-313.pyc b/src/literature_service/__pycache__/manual_upload.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdb228d7a061729fde2d5384492ef3c97899ca0d Binary files /dev/null and 
b/src/literature_service/__pycache__/manual_upload.cpython-313.pyc differ diff --git a/src/literature_service/__pycache__/pageindex_client.cpython-310.pyc b/src/literature_service/__pycache__/pageindex_client.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66844e0d235e0b3f257d5671d7e16bb825d42072 Binary files /dev/null and b/src/literature_service/__pycache__/pageindex_client.cpython-310.pyc differ diff --git a/src/literature_service/__pycache__/pageindex_client.cpython-313.pyc b/src/literature_service/__pycache__/pageindex_client.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4b1ba224f6563ad2f3e6ad1cf20478bd5279979 Binary files /dev/null and b/src/literature_service/__pycache__/pageindex_client.cpython-313.pyc differ diff --git a/src/literature_service/__pycache__/pipeline.cpython-310.pyc b/src/literature_service/__pycache__/pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c4f2b6e57eb26774b61a3ccd03075c493224c64 Binary files /dev/null and b/src/literature_service/__pycache__/pipeline.cpython-310.pyc differ diff --git a/src/literature_service/__pycache__/pipeline.cpython-313.pyc b/src/literature_service/__pycache__/pipeline.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1af958a4f53c1ad03f01d78ea3dd6715fe72aceb Binary files /dev/null and b/src/literature_service/__pycache__/pipeline.cpython-313.pyc differ diff --git a/src/literature_service/__pycache__/query_intent.cpython-310.pyc b/src/literature_service/__pycache__/query_intent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd9f1d3bc9fc91ce93aae9a5142775848642a513 Binary files /dev/null and b/src/literature_service/__pycache__/query_intent.cpython-310.pyc differ diff --git a/src/literature_service/__pycache__/query_intent.cpython-313.pyc b/src/literature_service/__pycache__/query_intent.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e53f330a06e37cf1b0e067676331f77d8df83a85 Binary files /dev/null and b/src/literature_service/__pycache__/query_intent.cpython-313.pyc differ diff --git a/src/literature_service/__pycache__/repository.cpython-310.pyc b/src/literature_service/__pycache__/repository.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bf447225571f49a30d96a411ba91c6cc514d80e Binary files /dev/null and b/src/literature_service/__pycache__/repository.cpython-310.pyc differ diff --git a/src/literature_service/__pycache__/repository.cpython-313.pyc b/src/literature_service/__pycache__/repository.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..424d06945f6a66c53bb215e94d5193ba6b5828a9 Binary files /dev/null and b/src/literature_service/__pycache__/repository.cpython-313.pyc differ diff --git a/src/literature_service/db.py b/src/literature_service/db.py new file mode 100644 index 0000000000000000000000000000000000000000..cc40f8e77cd36339bed243112d6ae01810c882fa --- /dev/null +++ b/src/literature_service/db.py @@ -0,0 +1,496 @@ +from __future__ import annotations + +import sqlite3 +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Iterator, Optional + + +LATEST_MIGRATION_VERSION = "0004_literature_evidence_staging" + + +def utc_now_iso() -> str: + return datetime.now(timezone.utc).isoformat() + + +SCHEMA_SQL = """ +CREATE TABLE IF 
NOT EXISTS projects ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + status TEXT NOT NULL DEFAULT 'active', + default_model_provider TEXT NOT NULL DEFAULT 'openai_compatible', + default_model_name TEXT NOT NULL DEFAULT 'gpt-oss:latest', + target_properties_json TEXT NOT NULL DEFAULT '[]', + extraction_instructions TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS papers ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + external_paper_id TEXT NOT NULL, + source TEXT NOT NULL, + title TEXT NOT NULL, + authors_json TEXT NOT NULL DEFAULT '[]', + year INTEGER, + doi TEXT, + abstract TEXT, + venue TEXT, + citation_count INTEGER, + is_open_access INTEGER NOT NULL DEFAULT 0, + url TEXT, + landing_url TEXT, + pdf_url TEXT, + pdf_path TEXT, + match_reasons_json TEXT NOT NULL DEFAULT '[]', + download_status TEXT NOT NULL DEFAULT 'discovered', + download_error TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS uq_papers_project_external +ON papers(project_id, external_paper_id); + +CREATE TABLE IF NOT EXISTS extraction_runs ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + query TEXT NOT NULL, + strategy TEXT NOT NULL, + model_provider TEXT NOT NULL, + model_name TEXT NOT NULL, + status TEXT NOT NULL, + started_at TEXT NOT NULL, + ended_at TEXT, + stats_json TEXT NOT NULL DEFAULT '{}', + error_message TEXT, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS contextual_data_points ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + paper_id TEXT NOT NULL, + run_id TEXT, + polymer_name TEXT NOT NULL, + dopant TEXT, + dopant_ratio TEXT, + property_name TEXT NOT NULL, + raw_value TEXT NOT NULL, + raw_unit TEXT NOT NULL, + standardized_value REAL, + standardized_unit TEXT, + conditions_json TEXT NOT NULL DEFAULT '{}', + source_quote TEXT NOT NULL, + source_location TEXT, + extraction_confidence REAL NOT NULL DEFAULT 0.5, + quality_tier TEXT NOT NULL DEFAULT 'bronze', + validation_status TEXT NOT NULL DEFAULT 'pending', + reviewer_note TEXT, + edited_payload_json TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (paper_id) REFERENCES papers(id) ON DELETE CASCADE, + FOREIGN KEY (run_id) REFERENCES extraction_runs(id) ON DELETE SET NULL +); + +CREATE INDEX IF NOT EXISTS idx_points_project_paper_property_status +ON contextual_data_points(project_id, paper_id, property_name, validation_status); + +CREATE TABLE IF NOT EXISTS literature_extraction_jobs ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + paper_id TEXT NOT NULL, + extractor_version TEXT NOT NULL, + model_provider TEXT, + model_name TEXT, + status TEXT NOT NULL DEFAULT 'pending', + stats_json TEXT NOT NULL DEFAULT '{}', + error_message TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (paper_id) REFERENCES papers(id) ON DELETE CASCADE +); + +CREATE UNIQUE INDEX IF NOT EXISTS uq_literature_extraction_jobs_project_paper_version +ON literature_extraction_jobs(project_id, paper_id, extractor_version); + +CREATE TABLE IF NOT EXISTS literature_evidence ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + paper_id TEXT NOT NULL, + extraction_job_id TEXT, + material_name TEXT NOT NULL, + 
canonical_smiles TEXT, + property_key TEXT NOT NULL, + raw_value TEXT NOT NULL, + raw_unit TEXT NOT NULL, + standardized_value REAL, + standardized_unit TEXT, + conditions_json TEXT NOT NULL DEFAULT '{}', + method TEXT, + evidence_quote TEXT NOT NULL, + evidence_location TEXT, + extractor_version TEXT NOT NULL, + extraction_model TEXT, + extraction_confidence REAL NOT NULL DEFAULT 0.5, + quality_tier TEXT NOT NULL DEFAULT 'bronze', + review_status TEXT NOT NULL DEFAULT 'pending', + reviewer_note TEXT, + edited_payload_json TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (paper_id) REFERENCES papers(id) ON DELETE CASCADE, + FOREIGN KEY (extraction_job_id) REFERENCES literature_extraction_jobs(id) ON DELETE SET NULL +); + +CREATE INDEX IF NOT EXISTS idx_literature_evidence_project_property_review +ON literature_evidence(project_id, property_key, review_status); + +CREATE INDEX IF NOT EXISTS idx_literature_evidence_project_paper +ON literature_evidence(project_id, paper_id); + +CREATE TABLE IF NOT EXISTS literature_review_events ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + evidence_id TEXT NOT NULL, + action TEXT NOT NULL, + from_status TEXT, + to_status TEXT NOT NULL, + reviewer_note TEXT, + payload_json TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (evidence_id) REFERENCES literature_evidence(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS manual_uploads ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + paper_id TEXT NOT NULL, + uploaded_filename TEXT NOT NULL, + stored_path TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'stored', + error_message TEXT, + uploaded_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (paper_id) REFERENCES papers(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS query_sessions ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + original_query TEXT NOT NULL, + suggestions_json TEXT NOT NULL DEFAULT '[]', + clarification_required INTEGER NOT NULL DEFAULT 0, + clarification_payload_json TEXT NOT NULL DEFAULT '{}', + status TEXT NOT NULL DEFAULT 'ready', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS pageindex_documents ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + paper_id TEXT NOT NULL, + doc_id TEXT, + status TEXT NOT NULL DEFAULT 'pending', + last_polled_at TEXT, + error_message TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (paper_id) REFERENCES papers(id) ON DELETE CASCADE +); + +CREATE TABLE IF NOT EXISTS qa_messages ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + session_id TEXT NOT NULL, + role TEXT NOT NULL, + content TEXT NOT NULL, + doc_ids_json TEXT NOT NULL DEFAULT '[]', + model_provider TEXT, + model_name TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE +); +""" + + +def _recreate_views(conn: sqlite3.Connection) -> None: + conn.execute("DROP VIEW IF EXISTS v_literature_points_flat") + conn.execute("DROP VIEW IF EXISTS v_literature_evidence_flat") + conn.executescript( + """ + CREATE VIEW v_literature_points_flat AS + SELECT + le.id AS point_id, + le.project_id, + le.paper_id, + le.extraction_job_id AS run_id, + le.material_name AS polymer_name, + 
json_extract(le.conditions_json, '$.dopant') AS dopant, + json_extract(le.conditions_json, '$.dopant_ratio') AS dopant_ratio, + le.property_key AS property_name, + le.raw_value, + le.raw_unit, + le.standardized_value, + le.standardized_unit, + json_extract(le.conditions_json, '$.solvent') AS solvent, + json_extract(le.conditions_json, '$.annealing_temp_c') AS annealing_temp_c, + json_extract(le.conditions_json, '$.annealing_time_min') AS annealing_time_min, + json_extract(le.conditions_json, '$.spin_speed_rpm') AS spin_speed_rpm, + json_extract(le.conditions_json, '$.measurement_temp_k') AS measurement_temp_k, + le.evidence_quote AS source_quote, + le.evidence_location AS source_location, + le.extraction_confidence, + le.quality_tier, + le.review_status AS validation_status, + le.reviewer_note, + le.created_at, + le.updated_at + FROM literature_evidence le + UNION ALL + SELECT + p.id AS point_id, + p.project_id, + p.paper_id, + p.run_id, + p.polymer_name, + p.dopant, + p.dopant_ratio, + p.property_name, + p.raw_value, + p.raw_unit, + p.standardized_value, + p.standardized_unit, + json_extract(p.conditions_json, '$.solvent') AS solvent, + json_extract(p.conditions_json, '$.annealing_temp_c') AS annealing_temp_c, + json_extract(p.conditions_json, '$.annealing_time_min') AS annealing_time_min, + json_extract(p.conditions_json, '$.spin_speed_rpm') AS spin_speed_rpm, + json_extract(p.conditions_json, '$.measurement_temp_k') AS measurement_temp_k, + p.source_quote, + p.source_location, + p.extraction_confidence, + p.quality_tier, + p.validation_status, + p.reviewer_note, + p.created_at, + p.updated_at + FROM contextual_data_points p; + + CREATE VIEW v_literature_evidence_flat AS + SELECT + le.*, + p.title AS paper_title, + p.year AS paper_year, + p.venue AS paper_venue, + p.doi AS paper_doi, + p.landing_url, + p.pdf_url, + p.is_open_access + FROM literature_evidence le + JOIN papers p ON p.id = le.paper_id; + """ + ) + + +def _has_column(conn: sqlite3.Connection, table: str, column: str) -> bool: + cols = {r[1] for r in conn.execute(f"PRAGMA table_info({table})").fetchall()} + return column in cols + + +@dataclass +class Database: + db_path: Path + + def __post_init__(self) -> None: + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_db() + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") + return conn + + def _init_db(self) -> None: + with self._connect() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TEXT NOT NULL + ) + """ + ) + + base_row = conn.execute( + "SELECT version FROM schema_migrations WHERE version = ?", + ("0001_literature_project_base",), + ).fetchone() + if base_row is None: + conn.executescript(SCHEMA_SQL) + _recreate_views(conn) + conn.execute( + "INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", + ("0001_literature_project_base", utc_now_iso()), + ) + else: + conn.executescript(SCHEMA_SQL) + + self._apply_migration_0002(conn) + self._apply_migration_0003(conn) + self._apply_migration_0004(conn) + conn.commit() + + def _apply_migration_0002(self, conn: sqlite3.Connection) -> None: + version = "0002_project_extraction_schema" + row = conn.execute("SELECT version FROM schema_migrations WHERE version = ?", (version,)).fetchone() + if row is not None: + return + if not _has_column(conn, "projects", "target_properties_json"): + conn.execute("ALTER TABLE 
projects ADD COLUMN target_properties_json TEXT NOT NULL DEFAULT '[]'") + if not _has_column(conn, "projects", "extraction_instructions"): + conn.execute("ALTER TABLE projects ADD COLUMN extraction_instructions TEXT NOT NULL DEFAULT ''") + conn.execute( + "INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", + (version, utc_now_iso()), + ) + + def _apply_migration_0003(self, conn: sqlite3.Connection) -> None: + version = "0003_paper_card_metadata" + row = conn.execute("SELECT version FROM schema_migrations WHERE version = ?", (version,)).fetchone() + if row is not None: + return + paper_columns = { + "abstract": "TEXT", + "venue": "TEXT", + "citation_count": "INTEGER", + "is_open_access": "INTEGER NOT NULL DEFAULT 0", + "landing_url": "TEXT", + "match_reasons_json": "TEXT NOT NULL DEFAULT '[]'", + } + for name, ddl in paper_columns.items(): + if not _has_column(conn, "papers", name): + conn.execute(f"ALTER TABLE papers ADD COLUMN {name} {ddl}") + conn.execute( + "INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", + (version, utc_now_iso()), + ) + + def _apply_migration_0004(self, conn: sqlite3.Connection) -> None: + version = "0004_literature_evidence_staging" + row = conn.execute("SELECT version FROM schema_migrations WHERE version = ?", (version,)).fetchone() + if row is not None: + _recreate_views(conn) + return + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS literature_extraction_jobs ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + paper_id TEXT NOT NULL, + extractor_version TEXT NOT NULL, + model_provider TEXT, + model_name TEXT, + status TEXT NOT NULL DEFAULT 'pending', + stats_json TEXT NOT NULL DEFAULT '{}', + error_message TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (paper_id) REFERENCES papers(id) ON DELETE CASCADE + ); + + CREATE UNIQUE INDEX IF NOT EXISTS uq_literature_extraction_jobs_project_paper_version + ON literature_extraction_jobs(project_id, paper_id, extractor_version); + + CREATE TABLE IF NOT EXISTS literature_evidence ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + paper_id TEXT NOT NULL, + extraction_job_id TEXT, + material_name TEXT NOT NULL, + canonical_smiles TEXT, + property_key TEXT NOT NULL, + raw_value TEXT NOT NULL, + raw_unit TEXT NOT NULL, + standardized_value REAL, + standardized_unit TEXT, + conditions_json TEXT NOT NULL DEFAULT '{}', + method TEXT, + evidence_quote TEXT NOT NULL, + evidence_location TEXT, + extractor_version TEXT NOT NULL, + extraction_model TEXT, + extraction_confidence REAL NOT NULL DEFAULT 0.5, + quality_tier TEXT NOT NULL DEFAULT 'bronze', + review_status TEXT NOT NULL DEFAULT 'pending', + reviewer_note TEXT, + edited_payload_json TEXT, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (paper_id) REFERENCES papers(id) ON DELETE CASCADE, + FOREIGN KEY (extraction_job_id) REFERENCES literature_extraction_jobs(id) ON DELETE SET NULL + ); + + CREATE INDEX IF NOT EXISTS idx_literature_evidence_project_property_review + ON literature_evidence(project_id, property_key, review_status); + + CREATE INDEX IF NOT EXISTS idx_literature_evidence_project_paper + ON literature_evidence(project_id, paper_id); + + CREATE TABLE IF NOT EXISTS literature_review_events ( + id TEXT PRIMARY KEY, + project_id TEXT NOT NULL, + evidence_id TEXT NOT NULL, + action TEXT NOT NULL, + from_status TEXT, + 
to_status TEXT NOT NULL, + reviewer_note TEXT, + payload_json TEXT, + created_at TEXT NOT NULL, + FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE, + FOREIGN KEY (evidence_id) REFERENCES literature_evidence(id) ON DELETE CASCADE + ); + """ + ) + _recreate_views(conn) + conn.execute( + "INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", + (version, utc_now_iso()), + ) + + @contextmanager + def connect(self) -> Iterator[sqlite3.Connection]: + conn = self._connect() + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + +_DB_SINGLETON: Optional[Database] = None + + +def get_database(db_path: str | Path = "data/app.db") -> Database: + global _DB_SINGLETON + path = Path(db_path) + if _DB_SINGLETON is None or _DB_SINGLETON.db_path != path: + _DB_SINGLETON = Database(path) + return _DB_SINGLETON diff --git a/src/literature_service/manual_upload.py b/src/literature_service/manual_upload.py new file mode 100644 index 0000000000000000000000000000000000000000..4f7a7a9a1d7e9036a97b23c20d68ef00ae145dc9 --- /dev/null +++ b/src/literature_service/manual_upload.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional + +from .repository import ManualUploadRepo, PaperRepo + + +@dataclass +class ManualUploadService: + paper_repo: PaperRepo + upload_repo: ManualUploadRepo + storage_root: Path = Path("data/literature/projects") + + def attach_manual_pdf( + self, + project_id: str, + paper_id: str, + uploaded_filename: str, + content: bytes, + ) -> Dict[str, Any]: + if not uploaded_filename.lower().endswith(".pdf"): + raise ValueError("Only PDF files are supported") + if not content: + raise ValueError("Uploaded file is empty") + + paper = self.paper_repo.get_paper(project_id, paper_id) + if not paper: + raise ValueError(f"paper not found: {paper_id}") + + safe_name = f"{paper_id}.pdf" + target_dir = self.storage_root / project_id / "raw_pdfs" + target_dir.mkdir(parents=True, exist_ok=True) + target_path = target_dir / safe_name + target_path.write_bytes(content) + + updated = self.paper_repo.mark_uploaded_manual(project_id, paper_id, str(target_path)) + upload = self.upload_repo.create_upload( + project_id=project_id, + paper_id=paper_id, + uploaded_filename=uploaded_filename, + stored_path=str(target_path), + status="stored", + ) + return {"paper": updated, "upload": upload} diff --git a/src/literature_service/pageindex_client.py b/src/literature_service/pageindex_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4d11e100ee1f1a4d556fc85c998bd02a5a83f451 --- /dev/null +++ b/src/literature_service/pageindex_client.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + + +@dataclass +class PageIndexService: + api_key: str + + def _client(self): + try: + from pageindex import PageIndexClient # type: ignore + except ImportError as exc: + raise RuntimeError("pageindex SDK not installed. 
Install with: pip install -U pageindex") from exc + return PageIndexClient(api_key=self.api_key) + + def submit_document(self, path: str) -> str: + client = self._client() + result = client.submit_document(path) + doc_id = result.get("doc_id") + if not doc_id: + raise RuntimeError(f"submit_document returned no doc_id: {result}") + return doc_id + + def get_document_status(self, doc_id: str) -> str: + client = self._client() + status = client.get_document(doc_id).get("status") + return status or "unknown" + + def chat_completions(self, message: str, doc_id: str) -> str: + client = self._client() + response = client.chat_completions( + messages=[{"role": "user", "content": message}], + doc_id=doc_id, + ) + choices = response.get("choices", []) + if not choices: + return "" + message_obj = choices[0].get("message", {}) + return message_obj.get("content", "") diff --git a/src/literature_service/pipeline.py b/src/literature_service/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8ee4e0e1e855892404b1126324421f999fc5acfa --- /dev/null +++ b/src/literature_service/pipeline.py @@ -0,0 +1,641 @@ +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional + +from literature import ( + ContextualizedExtractor, + DataExtractor, + PaperDiscoveryAgent, + PDFRetriever, + QualityAssessor, + UnitStandardizer, + extract_text_from_pdf, +) +from literature.extraction import is_expected_skip_error +from literature.property_registry import detect_property_keys, normalize_property_key, property_display_name +from literature.schemas import ( + ContextualizedValue, + ExperimentalConditions, + LiteratureQuerySpec, + LiteratureSupportSummary, + PaperCardResult, + PaperMetadata, + PaperSource, + PolymerDataPoint, +) + +from .db import get_database +from .query_intent import QueryIntentService +from .repository import ( + DataPointRepo, + ExtractionJobRepo, + PaperRepo, + ProjectRepo, + QuerySessionRepo, + RunRepo, +) + +logger = logging.getLogger(__name__) + +SIMULATION_KEYWORDS = [ + "simulation", + "molecular dynamics", + "finite element", + "dft", + "density functional theory", + "monte carlo", +] + + +def _legacy_point_to_contextual(dp: PolymerDataPoint) -> List[ContextualizedValue]: + out: List[ContextualizedValue] = [] + measurements = [ + ("electrical_conductivity", dp.electrical_conductivity_s_cm, "S/cm"), + ("thermal_conductivity", dp.thermal_conductivity_w_mk, "W/(m*K)"), + ("seebeck_coefficient", dp.seebeck_coefficient_uv_k, "uV/K"), + ("power_factor", dp.power_factor_uw_m_k2, "uW/(m*K^2)"), + ("zt_figure_of_merit", dp.zt_figure_of_merit, ""), + ] + conditions = ExperimentalConditions( + solvent=dp.solvent, + concentration_mg_ml=dp.concentration_mg_ml, + spin_speed_rpm=dp.spin_speed_rpm, + spin_time_s=dp.spin_time_s, + annealing_temp_c=dp.annealing_temp_c, + annealing_time_min=dp.annealing_time_min, + annealing_atmosphere=dp.annealing_atmosphere, + film_thickness_nm=dp.film_thickness_nm, + ) + source_quote = dp.raw_text_snippet or f"Legacy extraction from {dp.source_table_or_figure or dp.source_paper_id}." + if len(source_quote.strip()) < 10: + source_quote = f"Legacy extraction reference: {dp.source_paper_id}." 
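+ # Fan the legacy row out into one ContextualizedValue per populated measurement;
+ # every value shares the conditions block and provenance quote built above.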
+ + for property_name, value, unit in measurements: + if value is None: + continue + out.append( + ContextualizedValue( + polymer_name=dp.polymer_name, + dopant=dp.dopant, + dopant_ratio=dp.dopant_ratio, + property_name=property_name, + raw_value=str(value), + raw_unit=unit, + conditions=conditions, + source_quote=source_quote, + source_location=dp.source_table_or_figure, + extraction_confidence=dp.extraction_confidence, + quality_tier=dp.quality_tier, + ) + ) + return out + + +def _paper_from_row(row: Dict[str, Any]) -> PaperMetadata: + try: + source = PaperSource(row["source"]) + except Exception: + source = PaperSource.UNKNOWN + return PaperMetadata( + id=row["external_paper_id"], + title=row["title"], + authors=json.loads(row.get("authors_json") or "[]"), + year=row.get("year"), + doi=row.get("doi"), + abstract=row.get("abstract"), + venue=row.get("venue"), + citation_count=row.get("citation_count"), + is_open_access=bool(row.get("is_open_access")) if row.get("is_open_access") is not None else None, + source=source, + url=row.get("url"), + landing_url=row.get("landing_url"), + pdf_url=row.get("pdf_url"), + pdf_path=row.get("pdf_path"), + match_reasons=json.loads(row.get("match_reasons_json") or "[]"), + ) + + +@dataclass +class LiteraturePipeline: + db_path: str = "data/app.db" + + def __post_init__(self) -> None: + db = get_database(self.db_path) + self.projects = ProjectRepo(db) + self.papers = PaperRepo(db) + self.runs = RunRepo(db) + self.points = DataPointRepo(db) + self.extraction_jobs = ExtractionJobRepo(db) + self.query_sessions = QuerySessionRepo(db) + self.query_intent = QueryIntentService(self.query_sessions) + self.standardizer = UnitStandardizer() + self.quality = QualityAssessor() + + def ensure_default_project(self, name: str = "Production Literature") -> Dict[str, Any]: + return self.projects.ensure_project( + name=name, + description="Persistent production literature evidence staging for the platform UI.", + ) + + def project_name_for_spec(self, spec: LiteratureQuerySpec) -> str: + base = spec.polymer_name or spec.user_query or "query" + base = re.sub(r"\s+", " ", base).strip()[:80] + prop = normalize_property_key(spec.property_key) if spec.property_key else "none" + return f"Literature::{spec.mode.value}::{base}::{prop}" + + def create_project( + self, + name: str, + description: str = "", + default_model_provider: str = "openai_compatible", + default_model_name: str = "gpt-oss:latest", + ) -> Dict[str, Any]: + return self.projects.create_project( + name=name, + description=description, + default_model_provider=default_model_provider, + default_model_name=default_model_name, + ) + + def build_search_query(self, spec: LiteratureQuerySpec) -> str: + terms: List[str] = [] + if spec.mode.value == "material-first" and spec.polymer_name: + terms.append(spec.polymer_name) + elif spec.mode.value == "property-first" and spec.property_key: + terms.append(property_display_name(spec.property_key)) + if spec.user_query: + terms.append(spec.user_query) + if spec.property_key and property_display_name(spec.property_key) not in terms: + terms.append(property_display_name(spec.property_key)) + return " ".join(term for term in terms if term).strip() + + def resolve_target_properties( + self, + spec: LiteratureQuerySpec, + *, + extra_query_text: str = "", + ) -> List[str]: + keys: List[str] = [] + if spec.property_key: + normalized = normalize_property_key(spec.property_key) + if normalized: + keys.append(normalized) + for candidate in detect_property_keys(" ".join([spec.user_query, 
extra_query_text])): + if candidate not in keys: + keys.append(candidate) + return keys + + def _paper_match_reasons(self, paper: PaperMetadata, spec: LiteratureQuerySpec) -> List[str]: + text = " ".join(filter(None, [paper.title, paper.abstract or "", paper.venue or ""])).lower() + reasons: List[str] = [] + if spec.polymer_name and spec.polymer_name.lower() in text: + reasons.append(f"Material match: {spec.polymer_name}") + if spec.property_key: + display = property_display_name(spec.property_key) + if display.split(" (")[0].lower() in text: + reasons.append(f"Property match: {display}") + if not reasons and spec.user_query: + query_terms = [tok for tok in spec.user_query.lower().split() if len(tok) > 3] + overlaps = [tok for tok in query_terms if tok in text] + if overlaps: + reasons.append("Query overlap: " + ", ".join(sorted(set(overlaps[:4])))) + if not reasons: + reasons.append("Relevant to the search query") + return reasons + + def _paper_background_status(self, paper_row: Dict[str, Any], evidence_count: int, job_status: Optional[str]) -> str: + if evidence_count > 0: + return "evidence_staged" + if job_status in {"running", "pending"}: + return "extracting" + if job_status == "skipped": + return "extraction_skipped" + if job_status == "failed": + return "extraction_failed" + status = paper_row.get("download_status") or "discovered" + if status in {"downloaded", "uploaded_manual"}: + return "pdf_ready" + return status + + def run_discovery( + self, + project_id: str, + query: str, + limit: int = 10, + *, + spec: Optional[LiteratureQuerySpec] = None, + ) -> List[Dict[str, Any]]: + if PaperDiscoveryAgent is None: + raise RuntimeError("Literature discovery dependencies are not installed.") + discovery = PaperDiscoveryAgent() + papers = discovery.discover(query, limit_per_source=limit) + if spec is not None: + for paper in papers: + paper.match_reasons = self._paper_match_reasons(paper, spec) + rows = [self.papers.upsert_from_metadata(project_id, p) for p in papers] + return rows + + def run_retrieval(self, project_id: str, paper_rows: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: + if PDFRetriever is None: + raise RuntimeError("Literature retrieval dependencies are not installed.") + retriever = PDFRetriever() + metadata = [_paper_from_row(r) for r in paper_rows] + out = retriever.retrieve_batch(metadata) + + updated: List[Dict[str, Any]] = [] + for paper in out: + row = self.papers.update_download_result( + project_id, + paper.id, + pdf_path=paper.pdf_path, + error=None if paper.pdf_path else "automatic download failed", + ) + if row: + updated.append(row) + return updated + + def _prepare_full_text(self, paper: PaperMetadata, use_full_text: bool) -> PaperMetadata: + if use_full_text and paper.pdf_path: + text = extract_text_from_pdf(paper.pdf_path, max_pages=12) + if text: + paper.full_text = f"Title: {paper.title}\n\n{text}" + return paper + + def _normalize_points(self, points: List[ContextualizedValue]) -> List[ContextualizedValue]: + staged: List[ContextualizedValue] = [] + for point in points: + point.property_name = normalize_property_key(point.property_name) or point.property_name + result = self.standardizer.standardize(point.property_name, point.raw_value, point.raw_unit) + if result.success: + point.standardized_value = result.value + point.standardized_unit = result.unit + else: + point.standardization_error = result.error + continue + + is_valid, _ = self.quality.validate_contextual_value(point) + if not is_valid: + continue + point.quality_tier = 
self.quality.assess_contextual_quality(point) + staged.append(point) + return staged + + def run_extraction( + self, + project_id: str, + run_id: Optional[str], + paper_rows: Iterable[Dict[str, Any]], + *, + strategy: str = "simple", + model_provider: str = "openai_compatible", + model_name: Optional[str] = None, + extractor_version: str = "production-v1", + use_full_text: bool = True, + target_properties: Optional[List[str]] = None, + extraction_instructions: str = "", + canonical_smiles: Optional[str] = None, + force: bool = False, + ) -> Dict[str, Any]: + if ContextualizedExtractor is None or DataExtractor is None: + raise RuntimeError("Literature extraction dependencies are not installed.") + paper_rows = list(paper_rows) + contextual = ContextualizedExtractor( + model_id=model_name, + target_properties=target_properties, + extra_instructions=extraction_instructions, + ) + legacy = DataExtractor( + strategy=strategy, + target_properties=target_properties, + extra_instructions=extraction_instructions, + ) + contextual_ready = contextual.is_configured() + legacy_ready = legacy.can_attempt_extraction() + + if not contextual_ready and not legacy_ready: + skip_message = "Structured extraction skipped: no LLM or PageIndex backend is configured." + for row in paper_rows: + self.extraction_jobs.upsert_job( + project_id, + row["id"], + extractor_version, + model_provider=model_provider, + model_name=model_name, + status="skipped", + stats={"skip_reason": "no_extraction_backend"}, + error_message=skip_message, + ) + return { + "papers_extracted": 0, + "points_inserted": 0, + "fallback_papers": 0, + "skipped_completed": 0, + "skipped_unconfigured": len(paper_rows), + "skip_reason": "no_extraction_backend", + } + + inserted = 0 + paper_count = 0 + fallback_count = 0 + skipped_completed = 0 + skipped_unconfigured = 0 + + for row in paper_rows: + existing_job = self.extraction_jobs.get_job(project_id, row["id"], extractor_version) + if existing_job and existing_job.get("status") == "completed" and not force: + skipped_completed += 1 + continue + + job = self.extraction_jobs.upsert_job( + project_id, + row["id"], + extractor_version, + model_provider=model_provider, + model_name=model_name, + status="running", + ) + paper = _paper_from_row(row) + paper = self._prepare_full_text(paper, use_full_text=use_full_text) + + result = None + points: List[ContextualizedValue] = [] + last_error_message: Optional[str] = None + if contextual_ready: + result = contextual.extract_from_paper(paper, use_full_text=use_full_text) + last_error_message = result.error_message + if result and result.success and result.data_points: + points = [p for p in result.data_points if isinstance(p, ContextualizedValue)] + + if not points and legacy_ready: + fallback_count += 1 + legacy_results = legacy.extract_from_papers([paper], use_full_text=use_full_text) + for legacy_result in legacy_results: + if getattr(legacy_result, "error_message", None): + last_error_message = getattr(legacy_result, "error_message", None) + if getattr(legacy_result, "success", False): + for point in getattr(legacy_result, "data_points", []) or []: + if isinstance(point, ContextualizedValue): + points.append(point) + elif isinstance(point, PolymerDataPoint): + points.extend(_legacy_point_to_contextual(point)) + + staged = self._normalize_points(points) + if staged: + inserted += self.points.insert_points( + project_id, + row["id"], + job["id"], + staged, + extractor_version=extractor_version, + extraction_model=model_name, + 
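+ # canonical_smiles stays None unless the caller attached a structure to the query
+ # (process_top_papers forwards spec.canonical_smiles here).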
canonical_smiles=canonical_smiles, + ) + paper_count += 1 + self.extraction_jobs.upsert_job( + project_id, + row["id"], + extractor_version, + model_provider=model_provider, + model_name=model_name, + status="completed", + stats={"inserted": len(staged)}, + ) + else: + job_status = "skipped" if is_expected_skip_error(last_error_message) else "completed" + if job_status == "skipped": + skipped_unconfigured += 1 + self.extraction_jobs.upsert_job( + project_id, + row["id"], + extractor_version, + model_provider=model_provider, + model_name=model_name, + status=job_status, + stats={"inserted": 0}, + error_message=last_error_message or (None if points else "no_valid_points"), + ) + + return { + "papers_extracted": paper_count, + "points_inserted": inserted, + "fallback_papers": fallback_count, + "skipped_completed": skipped_completed, + "skipped_unconfigured": skipped_unconfigured, + } + + def run_quality(self, project_id: str) -> Dict[str, Any]: + rows = self.points.list_evidence(project_id) + total = len(rows) + gold = sum(1 for row in rows if row.get("quality_tier") == "gold") + silver = sum(1 for row in rows if row.get("quality_tier") == "silver") + bronze = sum(1 for row in rows if row.get("quality_tier") == "bronze") + pending = sum(1 for row in rows if row.get("review_status") == "pending") + approved = sum(1 for row in rows if row.get("review_status") == "approved") + return { + "total_points": total, + "gold_count": gold, + "silver_count": silver, + "bronze_count": bronze, + "pending_review": pending, + "approved_points": approved, + } + + def run_full_pipeline( + self, + project_id: str, + query: str, + limit: int, + *, + strategy: str = "simple", + model_provider: str = "openai_compatible", + model_name: str = "gpt-oss:latest", + use_full_text: bool = True, + target_properties: Optional[List[str]] = None, + extraction_instructions: str = "", + extractor_version: str = "production-v1", + ) -> Dict[str, Any]: + run = self.runs.create_run( + project_id=project_id, + query=query, + strategy=strategy, + model_provider=model_provider, + model_name=model_name, + ) + try: + discovered = self.run_discovery(project_id, query=query, limit=limit) + retrieved = self.run_retrieval(project_id, discovered) + extraction_stats = self.run_extraction( + project_id, + run["id"], + retrieved[: min(10, len(retrieved))], + strategy=strategy, + model_provider=model_provider, + model_name=model_name, + extractor_version=extractor_version, + use_full_text=use_full_text, + target_properties=target_properties, + extraction_instructions=extraction_instructions, + ) + quality = self.run_quality(project_id) + stats = { + "papers_discovered": len(discovered), + "papers_after_retrieval": len(retrieved), + **extraction_stats, + **quality, + } + self.runs.finish_run(run["id"], status="completed", stats=stats) + return {"run_id": run["id"], "status": "completed", "stats": stats} + except Exception as exc: + logger.exception("Pipeline failed") + self.runs.finish_run(run["id"], status="failed", error_message=str(exc), stats={}) + return {"run_id": run["id"], "status": "failed", "error": str(exc)} + + def search(self, spec: LiteratureQuerySpec) -> Dict[str, Any]: + project = self.projects.get_project(spec.project_id) if spec.project_id else None + if not project: + project = self.projects.ensure_project( + name=self.project_name_for_spec(spec), + description="Auto-managed literature project for a production UI query.", + ) + query_text = self.build_search_query(spec) + query_session = 
self.query_intent.analyze_and_store(project["id"], query_text) + discovered = self.run_discovery(project["id"], query_text, limit=spec.result_limit, spec=spec) + return { + "project_id": project["id"], + "query_text": query_text, + "query_session_id": query_session["id"], + "suggestions": json.loads(query_session.get("suggestions_json") or "[]"), + "clarification_required": bool(query_session.get("clarification_required")), + "papers": self.get_paper_cards(project["id"]), + } + + def process_top_papers( + self, + spec: LiteratureQuerySpec, + *, + model_provider: str = "openai_compatible", + model_name: str = "gpt-oss:latest", + extractor_version: str = "production-v1", + ) -> Dict[str, Any]: + if spec.project_id: + project_id = spec.project_id + else: + project_id = self.projects.ensure_project( + name=self.project_name_for_spec(spec), + description="Auto-managed literature project for a production UI query.", + )["id"] + target_properties = self.resolve_target_properties(spec) + papers = self.papers.list_papers(project_id)[: spec.top_k_extract] + retrieved = self.run_retrieval(project_id, papers) + extraction_stats = {"papers_extracted": 0, "points_inserted": 0, "fallback_papers": 0, "skipped_completed": 0} + if target_properties: + extraction_stats = self.run_extraction( + project_id, + run_id=None, + paper_rows=retrieved[: spec.top_k_extract], + strategy="pageindex", + model_provider=model_provider, + model_name=model_name, + extractor_version=extractor_version, + target_properties=target_properties, + canonical_smiles=spec.canonical_smiles, + ) + return { + "project_id": project_id, + "target_properties": target_properties, + "papers_processed": len(retrieved[: spec.top_k_extract]), + **extraction_stats, + } + + def get_paper_cards(self, project_id: str, limit: Optional[int] = None) -> List[PaperCardResult]: + papers = self.papers.list_papers(project_id) + if limit is not None: + papers = papers[:limit] + evidence_rows = self.points.list_evidence(project_id) + evidence_count: Dict[str, int] = {} + for row in evidence_rows: + evidence_count[row["paper_id"]] = evidence_count.get(row["paper_id"], 0) + 1 + jobs: Dict[str, Dict[str, Any]] = {} + for row in self.extraction_jobs.list_jobs(project_id): + jobs.setdefault(row["paper_id"], row) + + cards: List[PaperCardResult] = [] + for row in papers: + background_status = self._paper_background_status( + row, + evidence_count.get(row["id"], 0), + jobs.get(row["id"], {}).get("status"), + ) + cards.append( + PaperCardResult( + paper_id=row["id"], + title=row["title"], + year=row.get("year"), + venue=row.get("venue"), + doi=row.get("doi"), + landing_url=row.get("landing_url") or row.get("url"), + pdf_url=row.get("pdf_url"), + is_open_access=bool(row.get("is_open_access")), + match_reasons=json.loads(row.get("match_reasons_json") or "[]"), + background_status=background_status, + ) + ) + return cards + + def get_support_summary( + self, + project_id: str, + *, + material_name: Optional[str] = None, + property_key: Optional[str] = None, + ) -> LiteratureSupportSummary: + cards = self.get_paper_cards(project_id) + matched = len(cards) + oa = sum(1 for card in cards if card.is_open_access) + summary = self.points.support_summary( + project_id, + material_name=material_name, + property_key=property_key, + matched_paper_count=matched, + oa_paper_count=oa, + ) + return LiteratureSupportSummary(**summary) + + def list_reviewer_queue(self, project_id: str, limit: int = 50) -> List[Dict[str, Any]]: + return self.points.list_pending_review(project_id, 
limit=limit) + + def apply_review_action( + self, + evidence_id: str, + *, + action: str, + reviewer_note: Optional[str] = None, + edited_payload: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + status_map = { + "approve": "approved", + "edit_and_approve": "approved", + "reject": "rejected", + } + validation_status = status_map.get(action, action) + return self.points.update_review( + evidence_id, + validation_status=validation_status, + reviewer_note=reviewer_note, + edited_payload=edited_payload, + action=action, + ) + + def only_simulation_like(self, project_id: str) -> bool: + papers = self.papers.list_papers(project_id) + if not papers: + return False + for row in papers: + text = " ".join(filter(None, [row.get("title"), row.get("abstract") or ""])).lower() + if not any(keyword in text for keyword in SIMULATION_KEYWORDS): + return False + return True diff --git a/src/literature_service/query_intent.py b/src/literature_service/query_intent.py new file mode 100644 index 0000000000000000000000000000000000000000..c88e1c5b59679a635765d51c0852980799ae0239 --- /dev/null +++ b/src/literature_service/query_intent.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict + +from literature.clarifier import ClarifierAgent + +from .repository import QuerySessionRepo + + +@dataclass +class QueryIntentService: + repo: QuerySessionRepo + clarifier: ClarifierAgent = ClarifierAgent() + + def analyze_and_store(self, project_id: str, query: str) -> Dict[str, Any]: + analysis = self.clarifier.analyze(query) + return self.repo.create_session( + project_id=project_id, + original_query=query, + suggestions=analysis.suggestions, + clarification_required=analysis.clarification_required, + payload=analysis.to_payload(), + status=analysis.status, + ) diff --git a/src/literature_service/repository.py b/src/literature_service/repository.py new file mode 100644 index 0000000000000000000000000000000000000000..1d07967d656b421d31b9b1c59d0090f3632bd1ec --- /dev/null +++ b/src/literature_service/repository.py @@ -0,0 +1,834 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional +from uuid import uuid4 + +from literature.schemas import ContextualizedValue, LiteratureEvidenceRecord, PaperMetadata + +from .db import Database, get_database, utc_now_iso + + +def _id(prefix: str) -> str: + return f"{prefix}_{uuid4().hex}" + + +def _row_to_dict(row) -> Dict[str, Any]: + return dict(row) if row is not None else {} + + +@dataclass +class BaseRepo: + db: Database + + @classmethod + def with_default_db(cls): + return cls(get_database()) + + +class ProjectRepo(BaseRepo): + def create_project( + self, + name: str, + description: str = "", + default_model_provider: str = "openai_compatible", + default_model_name: str = "gpt-oss:latest", + target_properties: Optional[List[str]] = None, + extraction_instructions: str = "", + ) -> Dict[str, Any]: + pid = _id("proj") + now = utc_now_iso() + with self.db.connect() as conn: + conn.execute( + """ + INSERT INTO projects( + id, name, description, status, default_model_provider, default_model_name, + target_properties_json, extraction_instructions, created_at, updated_at + ) VALUES (?, ?, ?, 'active', ?, ?, ?, ?, ?, ?) 
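+ -- New projects always start as 'active'; created_at and updated_at share the same utc_now_iso() timestamp.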
+ """, + ( + pid, + name, + description, + default_model_provider, + default_model_name, + json.dumps(target_properties or []), + extraction_instructions, + now, + now, + ), + ) + row = conn.execute("SELECT * FROM projects WHERE id = ?", (pid,)).fetchone() + return _row_to_dict(row) + + def ensure_project( + self, + name: str, + *, + description: str = "", + default_model_provider: str = "openai_compatible", + default_model_name: str = "gpt-oss:latest", + ) -> Dict[str, Any]: + with self.db.connect() as conn: + row = conn.execute( + "SELECT * FROM projects WHERE name = ? AND status != 'archived' ORDER BY created_at ASC LIMIT 1", + (name,), + ).fetchone() + if row: + return _row_to_dict(row) + return self.create_project( + name=name, + description=description, + default_model_provider=default_model_provider, + default_model_name=default_model_name, + ) + + def list_projects(self, include_archived: bool = False) -> List[Dict[str, Any]]: + where = "" if include_archived else "WHERE status != 'archived'" + with self.db.connect() as conn: + rows = conn.execute(f"SELECT * FROM projects {where} ORDER BY created_at DESC").fetchall() + return [_row_to_dict(r) for r in rows] + + def get_project(self, project_id: str) -> Optional[Dict[str, Any]]: + with self.db.connect() as conn: + row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone() + return _row_to_dict(row) if row else None + + def update_project( + self, + project_id: str, + *, + name: Optional[str] = None, + description: Optional[str] = None, + status: Optional[str] = None, + default_model_provider: Optional[str] = None, + default_model_name: Optional[str] = None, + target_properties: Optional[List[str]] = None, + extraction_instructions: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + current = self.get_project(project_id) + if not current: + return None + now = utc_now_iso() + with self.db.connect() as conn: + conn.execute( + """ + UPDATE projects + SET name = ?, description = ?, status = ?, default_model_provider = ?, default_model_name = ?, + target_properties_json = ?, extraction_instructions = ?, updated_at = ? + WHERE id = ? + """, + ( + name if name is not None else current["name"], + description if description is not None else current["description"], + status if status is not None else current["status"], + default_model_provider if default_model_provider is not None else current["default_model_provider"], + default_model_name if default_model_name is not None else current["default_model_name"], + json.dumps(target_properties) if target_properties is not None else current.get("target_properties_json", "[]"), + extraction_instructions if extraction_instructions is not None else current.get("extraction_instructions", ""), + now, + project_id, + ), + ) + row = conn.execute("SELECT * FROM projects WHERE id = ?", (project_id,)).fetchone() + return _row_to_dict(row) if row else None + + +class PaperRepo(BaseRepo): + def upsert_from_metadata(self, project_id: str, paper: PaperMetadata) -> Dict[str, Any]: + now = utc_now_iso() + landing_url = paper.landing_url or paper.url + match_reasons_json = json.dumps(paper.match_reasons or []) + with self.db.connect() as conn: + existing = conn.execute( + "SELECT id, pdf_path, download_status FROM papers WHERE project_id = ? 
AND external_paper_id = ?", + (project_id, paper.id), + ).fetchone() + if existing: + pid = existing["id"] + conn.execute( + """ + UPDATE papers + SET source = ?, title = ?, authors_json = ?, year = ?, doi = ?, abstract = ?, venue = ?, + citation_count = ?, is_open_access = ?, url = ?, landing_url = ?, pdf_url = ?, + pdf_path = COALESCE(?, pdf_path), match_reasons_json = ?, download_status = ?, updated_at = ? + WHERE id = ? + """, + ( + paper.source.value, + paper.title, + json.dumps(paper.authors or []), + paper.year, + paper.doi, + paper.abstract, + paper.venue, + paper.citation_count, + 1 if paper.is_open_access else 0, + paper.url, + landing_url, + paper.pdf_url, + paper.pdf_path, + match_reasons_json, + "downloaded" if paper.pdf_path else existing["download_status"], + now, + pid, + ), + ) + else: + pid = _id("paper") + conn.execute( + """ + INSERT INTO papers( + id, project_id, external_paper_id, source, title, authors_json, year, doi, abstract, venue, + citation_count, is_open_access, url, landing_url, pdf_url, pdf_path, match_reasons_json, + download_status, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + pid, + project_id, + paper.id, + paper.source.value, + paper.title, + json.dumps(paper.authors or []), + paper.year, + paper.doi, + paper.abstract, + paper.venue, + paper.citation_count, + 1 if paper.is_open_access else 0, + paper.url, + landing_url, + paper.pdf_url, + paper.pdf_path, + match_reasons_json, + "downloaded" if paper.pdf_path else "discovered", + now, + now, + ), + ) + row = conn.execute("SELECT * FROM papers WHERE id = ?", (pid,)).fetchone() + return _row_to_dict(row) + + def update_download_result( + self, + project_id: str, + external_paper_id: str, + *, + pdf_path: Optional[str], + error: Optional[str], + ) -> Optional[Dict[str, Any]]: + status = "downloaded" if pdf_path else "download_failed" + with self.db.connect() as conn: + conn.execute( + """ + UPDATE papers + SET pdf_path = ?, download_status = ?, download_error = ?, updated_at = ? + WHERE project_id = ? AND external_paper_id = ? + """, + (pdf_path, status, error, utc_now_iso(), project_id, external_paper_id), + ) + row = conn.execute( + "SELECT * FROM papers WHERE project_id = ? AND external_paper_id = ?", + (project_id, external_paper_id), + ).fetchone() + return _row_to_dict(row) if row else None + + def get_paper(self, project_id: str, paper_id: str) -> Optional[Dict[str, Any]]: + with self.db.connect() as conn: + row = conn.execute( + "SELECT * FROM papers WHERE project_id = ? AND id = ?", + (project_id, paper_id), + ).fetchone() + return _row_to_dict(row) if row else None + + def list_papers(self, project_id: str) -> List[Dict[str, Any]]: + with self.db.connect() as conn: + rows = conn.execute( + "SELECT * FROM papers WHERE project_id = ? ORDER BY updated_at DESC, created_at DESC", + (project_id,), + ).fetchall() + return [_row_to_dict(r) for r in rows] + + def list_papers_by_ids(self, project_id: str, paper_ids: Iterable[str]) -> List[Dict[str, Any]]: + ids = list(paper_ids) + if not ids: + return [] + placeholders = ",".join("?" for _ in ids) + with self.db.connect() as conn: + rows = conn.execute( + f"SELECT * FROM papers WHERE project_id = ? AND id IN ({placeholders})", + [project_id, *ids], + ).fetchall() + return [_row_to_dict(r) for r in rows] + + def list_failed_papers(self, project_id: str) -> List[Dict[str, Any]]: + with self.db.connect() as conn: + rows = conn.execute( + """ + SELECT * FROM papers + WHERE project_id = ? 
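+ -- Only papers whose automatic download failed are returned; these are the candidates for manual PDF upload.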
AND download_status = 'download_failed' + ORDER BY updated_at DESC + """, + (project_id,), + ).fetchall() + return [_row_to_dict(r) for r in rows] + + def mark_uploaded_manual(self, project_id: str, paper_id: str, pdf_path: str) -> Optional[Dict[str, Any]]: + with self.db.connect() as conn: + conn.execute( + """ + UPDATE papers + SET pdf_path = ?, download_status = 'uploaded_manual', download_error = NULL, updated_at = ? + WHERE project_id = ? AND id = ? + """, + (pdf_path, utc_now_iso(), project_id, paper_id), + ) + row = conn.execute( + "SELECT * FROM papers WHERE project_id = ? AND id = ?", + (project_id, paper_id), + ).fetchone() + return _row_to_dict(row) if row else None + + +class RunRepo(BaseRepo): + def create_run( + self, + project_id: str, + query: str, + strategy: str, + model_provider: str, + model_name: str, + ) -> Dict[str, Any]: + rid = _id("run") + with self.db.connect() as conn: + conn.execute( + """ + INSERT INTO extraction_runs( + id, project_id, query, strategy, model_provider, model_name, status, started_at + ) VALUES (?, ?, ?, ?, ?, ?, 'running', ?) + """, + (rid, project_id, query, strategy, model_provider, model_name, utc_now_iso()), + ) + row = conn.execute("SELECT * FROM extraction_runs WHERE id = ?", (rid,)).fetchone() + return _row_to_dict(row) + + def finish_run( + self, + run_id: str, + *, + status: str, + stats: Optional[Dict[str, Any]] = None, + error_message: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + with self.db.connect() as conn: + conn.execute( + """ + UPDATE extraction_runs + SET status = ?, ended_at = ?, stats_json = ?, error_message = ? + WHERE id = ? + """, + (status, utc_now_iso(), json.dumps(stats or {}), error_message, run_id), + ) + row = conn.execute("SELECT * FROM extraction_runs WHERE id = ?", (run_id,)).fetchone() + return _row_to_dict(row) if row else None + + def list_runs(self, project_id: str, limit: int = 20) -> List[Dict[str, Any]]: + with self.db.connect() as conn: + rows = conn.execute( + "SELECT * FROM extraction_runs WHERE project_id = ? ORDER BY started_at DESC LIMIT ?", + (project_id, limit), + ).fetchall() + out = [] + for row in rows: + item = _row_to_dict(row) + item["stats_json"] = json.loads(item.get("stats_json") or "{}") + out.append(item) + return out + + +class ExtractionJobRepo(BaseRepo): + def upsert_job( + self, + project_id: str, + paper_id: str, + extractor_version: str, + *, + model_provider: Optional[str] = None, + model_name: Optional[str] = None, + status: str = "pending", + stats: Optional[Dict[str, Any]] = None, + error_message: Optional[str] = None, + ) -> Dict[str, Any]: + now = utc_now_iso() + with self.db.connect() as conn: + existing = conn.execute( + """ + SELECT * FROM literature_extraction_jobs + WHERE project_id = ? AND paper_id = ? AND extractor_version = ? + """, + (project_id, paper_id, extractor_version), + ).fetchone() + if existing: + jid = existing["id"] + conn.execute( + """ + UPDATE literature_extraction_jobs + SET model_provider = COALESCE(?, model_provider), + model_name = COALESCE(?, model_name), + status = ?, stats_json = ?, error_message = ?, updated_at = ? + WHERE id = ? + """, + ( + model_provider, + model_name, + status, + json.dumps(stats or {}), + error_message, + now, + jid, + ), + ) + else: + jid = _id("xjob") + conn.execute( + """ + INSERT INTO literature_extraction_jobs( + id, project_id, paper_id, extractor_version, model_provider, model_name, + status, stats_json, error_message, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
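+ -- First job row for this (project, paper, extractor_version); later runs take the UPDATE branch above.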
+ """, + ( + jid, + project_id, + paper_id, + extractor_version, + model_provider, + model_name, + status, + json.dumps(stats or {}), + error_message, + now, + now, + ), + ) + row = conn.execute("SELECT * FROM literature_extraction_jobs WHERE id = ?", (jid,)).fetchone() + return _row_to_dict(row) + + def get_job(self, project_id: str, paper_id: str, extractor_version: str) -> Optional[Dict[str, Any]]: + with self.db.connect() as conn: + row = conn.execute( + """ + SELECT * FROM literature_extraction_jobs + WHERE project_id = ? AND paper_id = ? AND extractor_version = ? + """, + (project_id, paper_id, extractor_version), + ).fetchone() + return _row_to_dict(row) if row else None + + def list_jobs(self, project_id: str) -> List[Dict[str, Any]]: + with self.db.connect() as conn: + rows = conn.execute( + "SELECT * FROM literature_extraction_jobs WHERE project_id = ? ORDER BY updated_at DESC", + (project_id,), + ).fetchall() + out = [] + for row in rows: + item = _row_to_dict(row) + item["stats_json"] = json.loads(item.get("stats_json") or "{}") + out.append(item) + return out + + +class DataPointRepo(BaseRepo): + def insert_points( + self, + project_id: str, + paper_id: str, + run_id: Optional[str], + points: Iterable[ContextualizedValue], + *, + extractor_version: str = "production-v1", + extraction_model: Optional[str] = None, + canonical_smiles: Optional[str] = None, + ) -> int: + now = utc_now_iso() + rows = [] + for point in points: + method = point.conditions.measurement_method + rows.append( + ( + _id("evi"), + project_id, + paper_id, + run_id, + point.polymer_name, + canonical_smiles, + point.property_name, + point.raw_value, + point.raw_unit, + point.standardized_value, + point.standardized_unit, + json.dumps(point.conditions.to_dict()), + method, + point.source_quote, + point.source_location, + extractor_version, + extraction_model, + point.extraction_confidence, + point.quality_tier.value, + "pending", + None, + None, + now, + now, + ) + ) + + if not rows: + return 0 + + with self.db.connect() as conn: + conn.executemany( + """ + INSERT INTO literature_evidence( + id, project_id, paper_id, extraction_job_id, material_name, canonical_smiles, property_key, + raw_value, raw_unit, standardized_value, standardized_unit, conditions_json, method, + evidence_quote, evidence_location, extractor_version, extraction_model, extraction_confidence, + quality_tier, review_status, reviewer_note, edited_payload_json, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + rows, + ) + return len(rows) + + def list_points( + self, + project_id: str, + validation_status: Optional[str] = None, + *, + property_name: Optional[str] = None, + material_name: Optional[str] = None, + paper_id: Optional[str] = None, + ) -> List[Dict[str, Any]]: + where = ["project_id = ?"] + params: List[Any] = [project_id] + if validation_status: + where.append("validation_status = ?") + params.append(validation_status) + if property_name: + where.append("property_name = ?") + params.append(property_name) + if material_name: + where.append("LOWER(polymer_name) LIKE ?") + params.append(f"%{material_name.lower()}%") + if paper_id: + where.append("paper_id = ?") + params.append(paper_id) + query = "SELECT * FROM v_literature_points_flat WHERE " + " AND ".join(where) + " ORDER BY created_at DESC" + with self.db.connect() as conn: + rows = conn.execute(query, params).fetchall() + return [_row_to_dict(r) for r in rows] + + def list_evidence( + self, + project_id: str, + *, + review_status: Optional[str] = None, + property_key: Optional[str] = None, + material_name: Optional[str] = None, + paper_id: Optional[str] = None, + limit: Optional[int] = None, + ) -> List[Dict[str, Any]]: + where = ["project_id = ?"] + params: List[Any] = [project_id] + if review_status: + where.append("review_status = ?") + params.append(review_status) + if property_key: + where.append("property_key = ?") + params.append(property_key) + if material_name: + where.append("LOWER(material_name) LIKE ?") + params.append(f"%{material_name.lower()}%") + if paper_id: + where.append("paper_id = ?") + params.append(paper_id) + sql = "SELECT * FROM v_literature_evidence_flat WHERE " + " AND ".join(where) + " ORDER BY updated_at DESC" + if limit: + sql += " LIMIT ?" + params.append(limit) + with self.db.connect() as conn: + rows = conn.execute(sql, params).fetchall() + return [_row_to_dict(r) for r in rows] + + def get_evidence(self, evidence_id: str) -> Optional[Dict[str, Any]]: + with self.db.connect() as conn: + row = conn.execute( + "SELECT * FROM v_literature_evidence_flat WHERE id = ?", + (evidence_id,), + ).fetchone() + return _row_to_dict(row) if row else None + + def list_pending_review(self, project_id: str, limit: int = 50) -> List[Dict[str, Any]]: + return self.list_evidence(project_id, review_status="pending", limit=limit) + + def update_review( + self, + point_id: str, + *, + validation_status: str, + reviewer_note: Optional[str] = None, + edited_payload: Optional[Dict[str, Any]] = None, + action: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + with self.db.connect() as conn: + row = conn.execute("SELECT * FROM literature_evidence WHERE id = ?", (point_id,)).fetchone() + if row: + current = _row_to_dict(row) + current_status = current.get("review_status") + next_payload = json.dumps(edited_payload, ensure_ascii=False) if edited_payload is not None else current.get("edited_payload_json") + next_note = reviewer_note if reviewer_note is not None else current.get("reviewer_note") + if ( + current_status == validation_status + and next_payload == current.get("edited_payload_json") + and next_note == current.get("reviewer_note") + ): + updated = conn.execute("SELECT * FROM literature_evidence WHERE id = ?", (point_id,)).fetchone() + return _row_to_dict(updated) if updated else None + + conn.execute( + """ + UPDATE literature_evidence + SET review_status = ?, reviewer_note = ?, edited_payload_json = ?, updated_at = ? + WHERE id = ? 
+ """, + (validation_status, next_note, next_payload, utc_now_iso(), point_id), + ) + event_id = _id("rev") + conn.execute( + """ + INSERT INTO literature_review_events( + id, project_id, evidence_id, action, from_status, to_status, reviewer_note, payload_json, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + event_id, + current["project_id"], + point_id, + action or validation_status, + current_status, + validation_status, + next_note, + next_payload, + utc_now_iso(), + ), + ) + updated = conn.execute("SELECT * FROM literature_evidence WHERE id = ?", (point_id,)).fetchone() + return _row_to_dict(updated) if updated else None + + # Legacy fallback + conn.execute( + """ + UPDATE contextual_data_points + SET validation_status = ?, reviewer_note = ?, edited_payload_json = ?, updated_at = ? + WHERE id = ? + """, + ( + validation_status, + reviewer_note, + json.dumps(edited_payload) if edited_payload else None, + utc_now_iso(), + point_id, + ), + ) + updated = conn.execute("SELECT * FROM contextual_data_points WHERE id = ?", (point_id,)).fetchone() + return _row_to_dict(updated) if updated else None + + def support_summary( + self, + project_id: str, + *, + material_name: Optional[str] = None, + property_key: Optional[str] = None, + matched_paper_count: Optional[int] = None, + oa_paper_count: Optional[int] = None, + ) -> Dict[str, Any]: + evidence = self.list_evidence( + project_id, + property_key=property_key, + material_name=material_name, + ) + total = len(evidence) + approved = sum(1 for row in evidence if row.get("review_status") == "approved") + experimental = approved > 0 or total > 0 + matched = matched_paper_count if matched_paper_count is not None else 0 + oa = oa_paper_count if oa_paper_count is not None else 0 + paper_component = min(matched, 10) / 10.0 + oa_component = (oa / matched) if matched else 0.0 + evidence_component = min(total, 10) / 10.0 + score = round((0.45 * paper_component + 0.20 * oa_component + 0.35 * evidence_component) * 100) + return { + "matched_paper_count": matched, + "oa_paper_count": oa, + "evidence_record_count": total, + "approved_record_count": approved, + "has_experimental_evidence": experimental, + "literature_support_score": max(0, min(100, score)), + } + + +class ManualUploadRepo(BaseRepo): + def create_upload( + self, + project_id: str, + paper_id: str, + uploaded_filename: str, + stored_path: str, + status: str = "stored", + error_message: Optional[str] = None, + ) -> Dict[str, Any]: + uid = _id("mu") + with self.db.connect() as conn: + conn.execute( + """ + INSERT INTO manual_uploads( + id, project_id, paper_id, uploaded_filename, stored_path, status, error_message, uploaded_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (uid, project_id, paper_id, uploaded_filename, stored_path, status, error_message, utc_now_iso()), + ) + row = conn.execute("SELECT * FROM manual_uploads WHERE id = ?", (uid,)).fetchone() + return _row_to_dict(row) + + +class QuerySessionRepo(BaseRepo): + def create_session( + self, + project_id: str, + original_query: str, + suggestions: List[str], + clarification_required: bool, + payload: Dict[str, Any], + status: str, + ) -> Dict[str, Any]: + sid = _id("qs") + now = utc_now_iso() + with self.db.connect() as conn: + conn.execute( + """ + INSERT INTO query_sessions( + id, project_id, original_query, suggestions_json, clarification_required, + clarification_payload_json, status, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + sid, + project_id, + original_query, + json.dumps(suggestions), + 1 if clarification_required else 0, + json.dumps(payload), + status, + now, + now, + ), + ) + row = conn.execute("SELECT * FROM query_sessions WHERE id = ?", (sid,)).fetchone() + return _row_to_dict(row) + + +class PageIndexRepo(BaseRepo): + def upsert_document( + self, + project_id: str, + paper_id: str, + *, + doc_id: Optional[str], + status: str, + error_message: Optional[str] = None, + ) -> Dict[str, Any]: + now = utc_now_iso() + with self.db.connect() as conn: + row = conn.execute( + "SELECT id FROM pageindex_documents WHERE project_id = ? AND paper_id = ?", + (project_id, paper_id), + ).fetchone() + if row: + did = row["id"] + conn.execute( + """ + UPDATE pageindex_documents + SET doc_id = ?, status = ?, last_polled_at = ?, error_message = ? + WHERE id = ? + """, + (doc_id, status, now, error_message, did), + ) + else: + did = _id("pidoc") + conn.execute( + """ + INSERT INTO pageindex_documents( + id, project_id, paper_id, doc_id, status, last_polled_at, error_message, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + (did, project_id, paper_id, doc_id, status, now, error_message, now), + ) + out = conn.execute("SELECT * FROM pageindex_documents WHERE id = ?", (did,)).fetchone() + return _row_to_dict(out) + + def list_documents(self, project_id: str) -> List[Dict[str, Any]]: + with self.db.connect() as conn: + rows = conn.execute( + "SELECT * FROM pageindex_documents WHERE project_id = ? ORDER BY created_at DESC", + (project_id,), + ).fetchall() + return [_row_to_dict(r) for r in rows] + + +class QAMessageRepo(BaseRepo): + def add_message( + self, + project_id: str, + session_id: str, + role: str, + content: str, + doc_ids: Optional[List[str]] = None, + model_provider: Optional[str] = None, + model_name: Optional[str] = None, + ) -> Dict[str, Any]: + mid = _id("qam") + with self.db.connect() as conn: + conn.execute( + """ + INSERT INTO qa_messages( + id, project_id, session_id, role, content, doc_ids_json, model_provider, model_name, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + mid, + project_id, + session_id, + role, + content, + json.dumps(doc_ids or []), + model_provider, + model_name, + utc_now_iso(), + ), + ) + row = conn.execute("SELECT * FROM qa_messages WHERE id = ?", (mid,)).fetchone() + return _row_to_dict(row) + + def list_messages(self, project_id: str, session_id: str) -> List[Dict[str, Any]]: + with self.db.connect() as conn: + rows = conn.execute( + """ + SELECT * FROM qa_messages + WHERE project_id = ? AND session_id = ? 
+ ORDER BY created_at ASC + """, + (project_id, session_id), + ).fetchall() + return [_row_to_dict(r) for r in rows] diff --git a/src/literature_ui.py b/src/literature_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..51cc1e624ad267770c64d9fdf7e330ccba6d005d --- /dev/null +++ b/src/literature_ui.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from typing import Iterable, List, Optional, Tuple + +import pandas as pd +import streamlit as st + +from literature.schemas import LiteratureQuerySpec, LiteratureSupportSummary, PaperCardResult, QueryMode +from src.lookup import PROPERTY_META + + +def property_select_options(include_blank: bool = True) -> Tuple[List[str], dict[str, Optional[str]]]: + label_to_key: dict[str, Optional[str]] = {} + options: List[str] = [] + if include_blank: + options.append("Any property") + label_to_key["Any property"] = None + for key in sorted(PROPERTY_META.keys()): + meta = PROPERTY_META[key] + label = f"{meta['name']} ({meta['unit']})" + options.append(label) + label_to_key[label] = key + return options, label_to_key + + +def render_support_summary(summary: LiteratureSupportSummary, *, only_simulation_like: bool = False) -> None: + c1, c2, c3, c4, c5 = st.columns(5) + c1.metric("Matched papers", summary.matched_paper_count) + c2.metric("OA papers", summary.oa_paper_count) + c3.metric("Evidence records", summary.evidence_record_count) + c4.metric("Approved", summary.approved_record_count) + c5.metric("Support score", summary.literature_support_score) + if only_simulation_like: + st.caption("Current support looks simulation/computation-heavy; no experiment-backed evidence has been staged yet.") + elif summary.has_experimental_evidence: + st.caption("Experiment-backed evidence is available in the staging queue.") + else: + st.caption("No experiment-backed evidence is staged yet.") + + +def render_extraction_runtime_notice(run_meta: Optional[dict]) -> None: + if not run_meta: + return + if run_meta.get("skip_reason") == "no_extraction_backend" or run_meta.get("skipped_unconfigured", 0): + st.info( + "Structured extraction is disabled in this environment. " + "Configure `PAGEINDEX_API_KEY` or a valid LLM backend to stage evidence records." 
+ ) + + +def render_paper_cards(cards: Iterable[PaperCardResult], *, max_cards: Optional[int] = None) -> None: + items = list(cards) + if max_cards is not None: + items = items[:max_cards] + if not items: + st.info("No papers found for this query.") + return + + for card in items: + with st.container(border=True): + header = card.title + year_part = f" ({card.year})" if card.year else "" + st.markdown(f"**{header}{year_part}**") + meta_parts = [part for part in [card.venue, card.doi, card.background_status] if part] + if meta_parts: + st.caption(" | ".join(meta_parts)) + if card.match_reasons: + st.write("Why it matched: " + "; ".join(card.match_reasons)) + links = [] + if card.landing_url: + links.append(f"[Landing page]({card.landing_url})") + if card.pdf_url: + links.append(f"[PDF]({card.pdf_url})") + if links: + st.markdown(" ".join(links)) + + +def render_evidence_table(rows: List[dict], *, title: str = "Staged evidence", max_rows: int = 25) -> None: + st.subheader(title) + if not rows: + st.info("No staged evidence records yet.") + return + df = pd.DataFrame(rows[:max_rows]).copy() + keep = [ + "material_name", "property_key", "raw_value", "raw_unit", + "standardized_value", "standardized_unit", "method", + "quality_tier", "review_status", "paper_title", "evidence_location", + ] + keep = [column for column in keep if column in df.columns] + display = df[keep].copy() + display = display.rename( + columns={ + "material_name": "Material", + "property_key": "Property", + "raw_value": "Raw value", + "raw_unit": "Raw unit", + "standardized_value": "Std. value", + "standardized_unit": "Std. unit", + "method": "Method", + "quality_tier": "Quality", + "review_status": "Review", + "paper_title": "Paper", + "evidence_location": "Location", + } + ) + display.index = range(1, len(display) + 1) + st.dataframe(display, width="stretch") + + +def render_candidate_evidence_explorer( + *, + prefix: str, + out_df: pd.DataFrame, + pipeline, + objective_props: List[str], +) -> None: + st.subheader("Evidence Explorer") + if out_df is None or out_df.empty: + st.info("Run discovery first to inspect literature support for candidates.") + return + if "SMILES" not in out_df.columns: + st.info("Evidence explorer unavailable: missing SMILES column.") + return + + label_to_index = {} + options: List[str] = [] + for idx, row in out_df.head(50).reset_index(drop=True).iterrows(): + polymer_name = row.get("polymer_name") or "Candidate" + label = f"{idx + 1}. 
{polymer_name} | {row['SMILES'][:40]}" + options.append(label) + label_to_index[label] = idx + + selected_label = st.selectbox("Candidate", options, key=f"{prefix}_candidate") + selected_row = out_df.head(50).reset_index(drop=True).iloc[label_to_index[selected_label]] + property_key = objective_props[0] if objective_props else None + query_name = selected_row.get("polymer_name") or selected_row["SMILES"] + + if st.button("Load literature support", key=f"{prefix}_load_literature"): + spec = LiteratureQuerySpec( + mode=QueryMode.MATERIAL, + user_query=f"{query_name} {' '.join(objective_props)}".strip(), + polymer_name=(selected_row.get("polymer_name") or None), + canonical_smiles=selected_row["SMILES"], + property_key=property_key, + top_k_extract=10, + result_limit=10, + ) + try: + with st.spinner("Searching and staging evidence…"): + result = pipeline.search(spec) + spec = spec.model_copy(update={"project_id": result["project_id"]}) + run_meta = pipeline.process_top_papers(spec) + st.session_state[f"{prefix}_project_id"] = result["project_id"] + st.session_state[f"{prefix}_run_meta"] = run_meta + except Exception as exc: + st.error(str(exc)) + + project_id = st.session_state.get(f"{prefix}_project_id") + if not project_id: + st.caption("Load literature support to attach paper cards and staged evidence to the selected candidate.") + return + + render_extraction_runtime_notice(st.session_state.get(f"{prefix}_run_meta")) + summary = pipeline.get_support_summary( + project_id, + material_name=(selected_row.get("polymer_name") or None), + property_key=property_key, + ) + render_support_summary(summary, only_simulation_like=pipeline.only_simulation_like(project_id)) + render_paper_cards(pipeline.get_paper_cards(project_id), max_cards=5) + evidence_rows = pipeline.points.list_evidence( + project_id, + property_key=property_key, + material_name=(selected_row.get("polymer_name") or None), + limit=10, + ) + render_evidence_table(evidence_rows, title="Candidate evidence", max_rows=10) diff --git a/src/lookup.py b/src/lookup.py new file mode 100644 index 0000000000000000000000000000000000000000..888003549072b3d1884559b7fd0087e24418eb95 --- /dev/null +++ b/src/lookup.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +import pandas as pd +import streamlit as st +from rdkit import Chem +from rdkit import RDLogger + +RDLogger.DisableLog("rdApp.*") + +# ---------------------------- +# Sources (property value files) +# ---------------------------- +SOURCES = ["EXP", "MD", "DFT", "GC"] + +SOURCE_LABELS = { + "EXP": "Experimental", + "MD": "Molecular Dynamics", + "DFT": "Density Functional Theory", + "GC": "Group Contribution", +} + +# ---------------------------- +# PolyInfo metadata file (name/class) +# ---------------------------- +POLYINFO_FILE = "data/POLYINFO.csv" # contains: SMILES, Polymer_Class, Polymer_Name + + +def canonicalize_smiles(smiles: str) -> str | None: + smiles = (smiles or "").strip() + if not smiles: + return None + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return None + return Chem.MolToSmiles(mol, canonical=True) + + +# --- Property meta (full name + unit) --- +PROPERTY_META = { + # Thermal + "tm": {"name": "Melting temperature", "unit": "K"}, + "tg": {"name": "Glass transition temperature", "unit": "K"}, + "td": {"name": "Thermal diffusivity", "unit": "m^2/s"}, + "tc": {"name": "Thermal conductivity", "unit": "W/m·K"}, + "cp": {"name": "Specific heat capacity", "unit": "J/kg·K"}, + # Mechanical + "young": {"name": "Young's modulus", "unit": "GPa"}, + "shear": 
{"name": "Shear modulus", "unit": "GPa"}, + "bulk": {"name": "Bulk modulus", "unit": "GPa"}, + "poisson": {"name": "Poisson ratio", "unit": "-"}, + # Transport + "visc": {"name": "Viscosity", "unit": "Pa·s"}, + "dif": {"name": "Diffusivity", "unit": "cm^2/s"}, + # Gas permeability + "phe": {"name": "He permeability", "unit": "Barrer"}, + "ph2": {"name": "H2 permeability", "unit": "Barrer"}, + "pco2": {"name": "CO2 permeability", "unit": "Barrer"}, + "pn2": {"name": "N2 permeability", "unit": "Barrer"}, + "po2": {"name": "O2 permeability", "unit": "Barrer"}, + "pch4": {"name": "CH4 permeability", "unit": "Barrer"}, + # Electronic / Optical + "alpha": {"name": "Polarizability", "unit": "a.u."}, + "homo": {"name": "HOMO energy", "unit": "eV"}, + "lumo": {"name": "LUMO energy", "unit": "eV"}, + "bandgap": {"name": "Band gap", "unit": "eV"}, + "mu": {"name": "Dipole moment", "unit": "Debye"}, + "etotal": {"name": "Total electronic energy", "unit": "eV"}, + "ri": {"name": "Refractive index", "unit": "-"}, + "dc": {"name": "Dielectric constant", "unit": "-"}, + "pe": {"name": "Permittivity", "unit": "-"}, + # Structural / Physical + "rg": {"name": "Radius of gyration", "unit": "Å"}, + "rho": {"name": "Density", "unit": "g/cm^3"}, +} + + +@st.cache_data +def load_source_csv(source: str) -> pd.DataFrame: + """ + Loads data/{SOURCE}.csv, normalizes: + - SMILES column -> 'smiles' + - property columns -> lowercase + - adds 'smiles_canon' + """ + path = f"data/{source}.csv" + df = pd.read_csv(path) + + # Normalize SMILES column name + if "SMILES" in df.columns: + df = df.rename(columns={"SMILES": "smiles"}) + elif "smiles" not in df.columns: + raise ValueError(f"{path} missing SMILES column") + + # Normalize property column names to lowercase + rename_map = {c: c.lower() for c in df.columns if c != "smiles"} + df = df.rename(columns=rename_map) + + # Canonicalize SMILES + df["smiles_canon"] = df["smiles"].astype(str).apply(canonicalize_smiles) + df = df.dropna(subset=["smiles_canon"]).reset_index(drop=True) + + return df + + +@st.cache_data +def build_index(df: pd.DataFrame) -> dict[str, int]: + """canonical smiles -> row index (first occurrence)""" + idx: dict[str, int] = {} + for i, s in enumerate(df["smiles_canon"].tolist()): + if s and s not in idx: + idx[s] = i + return idx + + +@st.cache_data +def load_polyinfo_csv() -> pd.DataFrame: + """ + Loads data/POLYINFO.csv with columns: + SMILES, Polymer_Class, Polymer_Name + Adds canonical smiles column 'smiles_canon'. + Returns empty df if file missing. 
+ """ + try: + df = pd.read_csv(POLYINFO_FILE) + except Exception: + return pd.DataFrame(columns=["smiles", "polymer_class", "polymer_name", "smiles_canon"]) + + # Normalize columns + if "SMILES" in df.columns: + df = df.rename(columns={"SMILES": "smiles"}) + elif "smiles" not in df.columns: + # If the file doesn't have a SMILES column as expected, return empty gracefully + return pd.DataFrame(columns=["smiles", "polymer_class", "polymer_name", "smiles_canon"]) + + # Normalize expected meta columns + ren = {} + if "Polymer_Class" in df.columns: + ren["Polymer_Class"] = "polymer_class" + if "Polymer_Name" in df.columns: + ren["Polymer_Name"] = "polymer_name" + df = df.rename(columns=ren) + + # Ensure the columns exist (even if missing in the file) + if "polymer_class" not in df.columns: + df["polymer_class"] = pd.NA + if "polymer_name" not in df.columns: + df["polymer_name"] = pd.NA + + # Canonicalize smiles + df["smiles_canon"] = df["smiles"].astype(str).apply(canonicalize_smiles) + df = df.dropna(subset=["smiles_canon"]).reset_index(drop=True) + + return df + + +@st.cache_data +def load_all_sources(): + """ + Returns dict: + db["EXP"/"MD"/"DFT"/"GC"] = {"df": df, "idx": idx} + db["POLYINFO"] = {"df": df, "idx": idx} + """ + db = {} + for src in SOURCES: + df = load_source_csv(src) + idx = build_index(df) + db[src] = {"df": df, "idx": idx} + + # PolyInfo metadata + pi_df = load_polyinfo_csv() + pi_idx = build_index(pi_df) if not pi_df.empty else {} + db["POLYINFO"] = {"df": pi_df, "idx": pi_idx} + + return db + + +def get_value(db, source: str, smiles_canon: str, prop_key: str): + pack = db[source] + df, idx = pack["df"], pack["idx"] + row_i = idx.get(smiles_canon, None) + if row_i is None: + return None + if prop_key not in df.columns: + return None + val = df.iloc[row_i][prop_key] + if pd.isna(val): + return None + return float(val) + + +def get_polyinfo(db, smiles_canon: str) -> tuple[str | None, str | None]: + """ + Returns (polymer_name, polymer_class) if available, else (None, None). + No 'not available' text here. 
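+
+    Typical use (illustrative):
+        name, cls = get_polyinfo(db, smiles_canon)
+        label = name or smiles_canon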
+ """ + pack = db.get("POLYINFO", None) + if pack is None: + return None, None + + df, idx = pack["df"], pack["idx"] + if df is None or df.empty: + return None, None + + row_i = idx.get(smiles_canon, None) + if row_i is None: + return None, None + + name = df.iloc[row_i].get("polymer_name", None) + cls = df.iloc[row_i].get("polymer_class", None) + + # Clean up NA / empty + if pd.isna(name) or str(name).strip() == "": + name = None + else: + name = str(name).strip() + + if pd.isna(cls) or str(cls).strip() == "": + cls = None + else: + cls = str(cls).strip() + + return name, cls diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7677b9500bdd978e79df3ccf45e6d02d1e975461 --- /dev/null +++ b/src/model.py @@ -0,0 +1,312 @@ +# model.py +from __future__ import annotations + +from typing import List, Optional, Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.data import Batch + +from src.conv import build_gnn_encoder, GNNEncoder + + +def get_activation(name: str) -> nn.Module: + name = name.lower() + if name == "relu": + return nn.ReLU() + if name == "gelu": + return nn.GELU() + if name == "silu": + return nn.SiLU() + if name in ("leaky_relu", "lrelu"): + return nn.LeakyReLU(0.1) + raise ValueError(f"Unknown activation: {name}") + + +class FiLM(nn.Module): + """ + Simple FiLM: gamma, beta from condition vector; apply to features as (1+gamma)*h + beta + """ + def __init__(self, feat_dim: int, cond_dim: int): + super().__init__() + self.gamma = nn.Linear(cond_dim, feat_dim) + self.beta = nn.Linear(cond_dim, feat_dim) + + def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + g = self.gamma(cond) + b = self.beta(cond) + return (1.0 + g) * h + b + + +class TaskHead(nn.Module): + """ + Per-task MLP head. Input is concatenation of [graph_embed, optional task_embed]. + Outputs either a mean only (scalar) or mean+logvar (heteroscedastic). + """ + def __init__( + self, + in_dim: int, + hidden_dim: int = 512, + depth: int = 2, + act: str = "relu", + dropout: float = 0.0, + heteroscedastic: bool = False, + ): + super().__init__() + layers: List[nn.Module] = [] + d = in_dim + for _ in range(depth): + layers.append(nn.Linear(d, hidden_dim)) + layers.append(get_activation(act)) + if dropout > 0: + layers.append(nn.Dropout(dropout)) + d = hidden_dim + out_dim = 2 if heteroscedastic else 1 + layers.append(nn.Linear(d, out_dim)) + self.net = nn.Sequential(*layers) + self.hetero = heteroscedastic + + def forward(self, z: torch.Tensor) -> torch.Tensor: + # returns [B, 1] or [B, 2] where [...,0] is mean and [...,1] is logvar if heteroscedastic + return self.net(z) + + +class MultiTaskMultiFidelityModel(nn.Module): + """ + General multi-task, multi-fidelity GNN. + + - Any number of tasks (properties) via T = len(task_names) + - Any number of fidelities via num_fids + - Fidelity conditioning with an embedding and FiLM on the graph embedding + - Optional task embeddings concatenated into each task head input + - Single forward returning predictions [B, T] (means); if heteroscedastic, also returns log-variances + + Expected input Batch fields (PyG): + - x : [N_nodes, F_node] + - edge_index : [2, N_edges] + - edge_attr : [N_edges, F_edge] (required if gnn_type="gine") + - batch : [N_nodes] + - fid_idx : [B] or [B, 1] long; integer fidelity per graph + + Notes: + - Targets should already be normalized outside the model; apply inverse transform for plots. 
+ - Loss weighting/equal-importance and curriculum happen in the trainer, not here. + """ + + def __init__( + self, + in_dim_node: int, + in_dim_edge: int, + task_names: List[str], + num_fids: int, + gnn_type: Literal["gine", "gin", "gcn"] = "gine", + gnn_emb_dim: int = 256, + gnn_layers: int = 5, + gnn_norm: Literal["batch", "layer", "none"] = "batch", + gnn_readout: Literal["mean", "sum", "max"] = "mean", + gnn_act: str = "relu", + gnn_dropout: float = 0.0, + gnn_residual: bool = True, + # Fidelity conditioning + fid_emb_dim: int = 64, + use_film: bool = True, + # Task conditioning + use_task_embed: bool = True, + task_emb_dim: int = 32, + # Heads + head_hidden: int = 512, + head_depth: int = 2, + head_act: str = "relu", + head_dropout: float = 0.0, + heteroscedastic: bool = False, + # Optional homoscedastic task uncertainty (used in loss, kept here for checkpoint parity) + use_task_uncertainty: bool = False, + # Embedding regularization (used via regularization_loss) + fid_emb_l2: float = 0.0, + task_emb_l2: float = 0.0, + ): + super().__init__() + self.task_names = list(task_names) + self.num_tasks = len(task_names) + self.num_fids = int(num_fids) + self.hetero = heteroscedastic + self.fid_emb_l2 = float(fid_emb_l2) + self.task_emb_l2 = float(task_emb_l2) + self.use_film = use_film + self.use_task_embed = use_task_embed + + # Optional learned homoscedastic uncertainty per task (trainer may use it) + self.use_task_uncertainty = bool(use_task_uncertainty) + if self.use_task_uncertainty: + self.task_log_sigma2 = nn.Parameter(torch.zeros(self.num_tasks)) + else: + self.task_log_sigma2 = None + + # Encoder + self.encoder: GNNEncoder = build_gnn_encoder( + in_dim_node=in_dim_node, + emb_dim=gnn_emb_dim, + num_layers=gnn_layers, + gnn_type=gnn_type, + in_dim_edge=in_dim_edge, + act=gnn_act, + dropout=gnn_dropout, + residual=gnn_residual, + norm=gnn_norm, + readout=gnn_readout, + ) + + # Fidelity embedding + FiLM + self.fid_embed = nn.Embedding(self.num_fids, fid_emb_dim) if fid_emb_dim > 0 else None + self.film = FiLM(gnn_emb_dim, fid_emb_dim) if (use_film and fid_emb_dim > 0) else None + + # --- Compute the true feature dim sent to heads --- + # If FiLM is ON: g stays [B, gnn_emb_dim] + # If FiLM is OFF but fid_embed exists: we CONCAT c → g becomes [B, gnn_emb_dim + fid_emb_dim] + self.gnn_out_dim = gnn_emb_dim + (fid_emb_dim if (self.fid_embed is not None and self.film is None) else 0) + + # Task embeddings + self.task_embed = nn.Embedding(self.num_tasks, task_emb_dim) if (use_task_embed and task_emb_dim > 0) else None + + # Per-task heads + head_in_dim = self.gnn_out_dim + (task_emb_dim if self.task_embed is not None else 0) + self.heads = nn.ModuleList([ + TaskHead( + in_dim=head_in_dim, + hidden_dim=head_hidden, + depth=head_depth, + act=head_act, + dropout=head_dropout, + heteroscedastic=heteroscedastic, + ) for _ in range(self.num_tasks) + ]) + + + def reset_parameters(self): + if self.fid_embed is not None: + nn.init.normal_(self.fid_embed.weight, mean=0.0, std=0.02) + if self.task_embed is not None: + nn.init.normal_(self.task_embed.weight, mean=0.0, std=0.02) + # Encoder/heads rely on their internal initializations. + + def forward(self, data: Batch) -> dict: + """ + Returns: + { + "pred": [B, T] means, + "logvar": [B, T] optional if heteroscedastic, + "h": [B, D] graph embedding after FiLM (useful for diagnostics). 
+ } + """ + x, edge_index = data.x, data.edge_index + edge_attr = getattr(data, "edge_attr", None) + batch = data.batch + if edge_attr is None and hasattr(self.encoder, "gnn_type") and self.encoder.gnn_type == "gine": + raise ValueError("GINE encoder requires edge_attr, but Batch.edge_attr is None.") + + # Graph embedding + g = self.encoder(x, edge_index, edge_attr, batch) # [B, D] + + # Fidelity conditioning + fid_idx = data.fid_idx.view(-1).long() # [B] + if self.fid_embed is not None: + c = self.fid_embed(fid_idx) # [B, C] + if self.film is not None: + g = self.film(g, c) # [B, D] + else: + g = torch.cat([g, c], dim=-1) + + # Per-task heads + preds: List[torch.Tensor] = [] + logvars: Optional[List[torch.Tensor]] = [] if self.hetero else None + for t_idx, head in enumerate(self.heads): + if self.task_embed is not None: + tvec = self.task_embed.weight[t_idx].unsqueeze(0).expand(g.size(0), -1) + z = torch.cat([g, tvec], dim=-1) + else: + z = g + out = head(z) # [B, 1] or [B, 2] + if self.hetero: + mu = out[..., 0:1] + lv = out[..., 1:2] + preds.append(mu) + logvars.append(lv) # type: ignore[arg-type] + else: + preds.append(out) + + pred = torch.cat(preds, dim=-1) # [B, T] + result = {"pred": pred, "h": g} + if self.hetero and logvars is not None: + result["logvar"] = torch.cat(logvars, dim=-1) # [B, T] + return result + + def regularization_loss(self) -> torch.Tensor: + """ + Optional small L2 on embeddings to keep them bounded. + """ + device = next(self.parameters()).device + reg = torch.zeros([], device=device) + if self.fid_embed is not None and self.fid_emb_l2 > 0: + reg = reg + self.fid_emb_l2 * (self.fid_embed.weight.pow(2).mean()) + if self.task_embed is not None and self.task_emb_l2 > 0: + reg = reg + self.task_emb_l2 * (self.task_embed.weight.pow(2).mean()) + return reg + + +def build_model( + *, + in_dim_node: int, + in_dim_edge: int, + task_names: List[str], + num_fids: int, + gnn_type: Literal["gine", "gin", "gcn"] = "gine", + gnn_emb_dim: int = 256, + gnn_layers: int = 5, + gnn_norm: Literal["batch", "layer", "none"] = "batch", + gnn_readout: Literal["mean", "sum", "max"] = "mean", + gnn_act: str = "relu", + gnn_dropout: float = 0.0, + gnn_residual: bool = True, + fid_emb_dim: int = 64, + use_film: bool = True, + use_task_embed: bool = True, + task_emb_dim: int = 32, + head_hidden: int = 512, + use_task_uncertainty: bool = False, + head_depth: int = 2, + head_act: str = "relu", + head_dropout: float = 0.0, + heteroscedastic: bool = False, + fid_emb_l2: float = 0.0, + task_emb_l2: float = 0.0, +) -> MultiTaskMultiFidelityModel: + """ + Factory to construct the multi-task, multi-fidelity model with a consistent API. 
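+
+    Minimal call (dimension values below are illustrative, not taken from any
+    dataset in this repo):
+
+        model = build_model(
+            in_dim_node=9,
+            in_dim_edge=3,
+            task_names=["tg", "tm"],
+            num_fids=2,
+        )
+        out = model(batch)  # PyG Batch with x, edge_index, edge_attr, batch, fid_idx
+        out["pred"]         # [B, 2] predicted means in normalized target space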
+ """ + return MultiTaskMultiFidelityModel( + in_dim_node=in_dim_node, + in_dim_edge=in_dim_edge, + task_names=task_names, + num_fids=num_fids, + gnn_type=gnn_type, + gnn_emb_dim=gnn_emb_dim, + gnn_layers=gnn_layers, + gnn_norm=gnn_norm, + gnn_readout=gnn_readout, + gnn_act=gnn_act, + gnn_dropout=gnn_dropout, + gnn_residual=gnn_residual, + fid_emb_dim=fid_emb_dim, + use_film=use_film, + use_task_embed=use_task_embed, + task_emb_dim=task_emb_dim, + head_hidden=head_hidden, + head_depth=head_depth, + head_act=head_act, + head_dropout=head_dropout, + heteroscedastic=heteroscedastic, + fid_emb_l2=fid_emb_l2, + task_emb_l2=task_emb_l2, + use_task_uncertainty=use_task_uncertainty, + ) diff --git a/src/predictor.py b/src/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..240bfa41ab09d90525a145ddfeb643bfe1586e2d --- /dev/null +++ b/src/predictor.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +import re +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch_geometric.data import Data + +from src.data_builder import featurize_smiles, TargetScaler +from src.model import build_model +from src.utils import to_device, apply_inverse_transform + + +# ------------------------- +# Unit correction (ML only) +# ------------------------- +POST_SCALE = { + "td": 1e-7, + "dif": 1e-5, + "visc": 1e-3, +} + + +def _load_scaler_compat(path: Path) -> TargetScaler: + blob = torch.load(path, map_location="cpu") + if "mean" not in blob or "std" not in blob: + raise RuntimeError(f"Unrecognized target_scaler format: {path}") + + ts = TargetScaler( + transforms=blob.get("transforms", None), + eps=blob.get("eps", None), + ) + ts.load_state_dict({ + "mean": blob["mean"].float(), + "std": blob["std"].float(), + "transforms": blob.get("transforms", ts.transforms), + "eps": blob.get("eps", ts.eps), + }) + ts.targets = [str(t).lower() for t in blob.get("targets", [])] + return ts + + +def _infer_seed_from_name(path: Path) -> Optional[int]: + m = re.search(r"_([0-9]+)\.pt$", path.name) + return int(m.group(1)) if m else None + + +def _make_one_graph(smiles: str) -> Data: + x, edge_index, edge_attr = featurize_smiles(smiles) + d = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + y=torch.zeros(1, 1), + y_mask=torch.zeros(1, 1, dtype=torch.bool), + fid_idx=torch.tensor([0], dtype=torch.long), + ) + d.smiles = smiles + return d + + +class SingleTaskEnsemblePredictor: + """ + Single-task ensemble: + models/single_models/{prop}_single_model_{seed}.pt + models/single_models/{prop}_single_scalar_{seed}.pt + """ + + def __init__(self, models_dir: str = "models/single_models", device: str = "cpu"): + self.models_dir = Path(models_dir) + self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu") + self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {} + + def available_seeds(self, prop: str) -> List[int]: + prop = prop.lower() + seeds = [] + for p in self.models_dir.glob(f"{prop}_single_model_*.pt"): + s = _infer_seed_from_name(p) + if s is not None: + seeds.append(s) + return sorted(set(seeds)) + + def _load_one(self, prop: str, seed: int): + prop = prop.lower() + key = (prop, seed) + if key in self._cache: + return self._cache[key] + + ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt" + scaler_path = self.models_dir / f"{prop}_single_scalar_{seed}.pt" + if not ckpt_path.exists() or not scaler_path.exists(): + raise 
FileNotFoundError(f"Missing model/scaler for {prop} seed {seed}") + + ckpt = torch.load(ckpt_path, map_location=self.device) + train_args = ckpt.get("args", {}) + + scaler = _load_scaler_compat(scaler_path) + task_names = list(getattr(scaler, "targets", [])) or [prop] + + meta = {"train_args": train_args, "task_names": task_names} + self._cache[key] = (None, scaler, meta) + return self._cache[key] + + def _build_model_if_needed(self, prop: str, seed: int, in_dim_node: int, in_dim_edge: int): + prop = prop.lower() + key = (prop, seed) + model, scaler, meta = self._cache[key] + if model is not None: + return model, scaler, meta + + train_args = meta["train_args"] + task_names = meta["task_names"] + + ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt" + ckpt = torch.load(ckpt_path, map_location=self.device) + state_dict = ckpt["model"] + + # infer num_fids from checkpoint + if "fid_embed.weight" in state_dict: + num_fids = state_dict["fid_embed.weight"].shape[0] + else: + num_fids = 1 + + model = build_model( + in_dim_node=in_dim_node, + in_dim_edge=in_dim_edge, + task_names=task_names, + num_fids=num_fids, + gnn_type=train_args.get("gnn_type", "gine"), + gnn_emb_dim=train_args.get("gnn_emb_dim", 256), + gnn_layers=train_args.get("gnn_layers", 5), + gnn_norm=train_args.get("gnn_norm", "batch"), + gnn_readout=train_args.get("gnn_readout", "mean"), + gnn_act=train_args.get("gnn_act", "relu"), + gnn_dropout=train_args.get("gnn_dropout", 0.0), + gnn_residual=train_args.get("gnn_residual", True), + fid_emb_dim=train_args.get("fid_emb_dim", 64), + use_film=train_args.get("use_film", True), + use_task_embed=train_args.get("use_task_embed", True), + task_emb_dim=train_args.get("task_emb_dim", 32), + head_hidden=train_args.get("head_hidden", 512), + head_depth=train_args.get("head_depth", 2), + head_act=train_args.get("head_act", "relu"), + head_dropout=train_args.get("head_dropout", 0.0), + heteroscedastic=train_args.get("heteroscedastic", False), + fid_emb_l2=0.0, + task_emb_l2=0.0, + use_task_uncertainty=train_args.get("task_uncertainty", False), + ).to(self.device) + + model.load_state_dict(state_dict, strict=True) + model.eval() + + self._cache[key] = (model, scaler, meta) + return model, scaler, meta + + def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]: + prop = prop.lower() + seeds = self.available_seeds(prop) + if not seeds: + return None, None, {} + + try: + g = _make_one_graph(smiles) + except Exception: + return None, None, {} + + in_dim_node = g.x.shape[1] + in_dim_edge = g.edge_attr.shape[1] + + per_seed: Dict[int, float] = {} + with torch.no_grad(): + for seed in seeds: + self._load_one(prop, seed) + model, scaler, meta = self._build_model_if_needed(prop, seed, in_dim_node, in_dim_edge) + + batch = to_device(g, self.device) + out = model(batch) + pred_n = out["pred"] # [1, 1] + pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1) + val = float(pred[0]) + + # unit correction + val *= POST_SCALE.get(prop, 1.0) + + per_seed[seed] = val + + vals = np.array(list(per_seed.values()), dtype=float) + mean = float(vals.mean()) + std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0 + return mean, std, per_seed diff --git a/src/predictor_multitask.py b/src/predictor_multitask.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7d0ee77161d6a3147b862911f1c9f9d6ad489f --- /dev/null +++ b/src/predictor_multitask.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import re 
+from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch_geometric.data import Data + +from src.data_builder import featurize_smiles, TargetScaler +from src.model import build_model +from src.utils import to_device, apply_inverse_transform + + +# ------------------------- +# Unit correction (ML only) +# ------------------------- +POST_SCALE = { + "td": 1e-7, + "dif": 1e-5, + "visc": 1e-3, +} + + +def _load_scaler_compat(path: Path) -> TargetScaler: + blob = torch.load(path, map_location="cpu") + if "mean" not in blob or "std" not in blob: + raise RuntimeError(f"Unrecognized target_scaler format: {path}") + + ts = TargetScaler( + transforms=blob.get("transforms", None), + eps=blob.get("eps", None), + ) + ts.load_state_dict({ + "mean": blob["mean"].float(), + "std": blob["std"].float(), + "transforms": blob.get("transforms", ts.transforms), + "eps": blob.get("eps", ts.eps), + }) + ts.targets = [str(t).lower() for t in blob.get("targets", [])] + return ts + + +def _infer_seed(path: Path) -> Optional[int]: + m = re.search(r"_([0-9]+)\.pt$", path.name) + return int(m.group(1)) if m else None + + +def _make_one_graph(smiles: str, T: int, fid_idx: int = 0) -> Data: + x, edge_index, edge_attr = featurize_smiles(smiles) + d = Data( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + y=torch.zeros(1, T), + y_mask=torch.zeros(1, T, dtype=torch.bool), + fid_idx=torch.tensor([fid_idx], dtype=torch.long), + ) + d.smiles = smiles + return d + + +class MultiTaskEnsemblePredictor: + """ + Multi-task ensemble: + models/multitask_models/{task}_model_{seed}.pt + models/multitask_models/{task}_scalar_{seed}.pt + """ + + def __init__(self, models_dir: str = "models/multitask_models", device: str = "cpu"): + self.models_dir = Path(models_dir) + self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu") + self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {} + + def available_seeds(self, task: str) -> List[int]: + task = task.strip().lower() + seeds = [] + for p in self.models_dir.glob(f"{task}_model_*.pt"): + s = _infer_seed(p) + if s is not None: + seeds.append(s) + return sorted(set(seeds)) + + def _load_one_meta(self, task: str, seed: int): + task = task.strip().lower() + key = (task, seed) + if key in self._cache: + return self._cache[key] + + ckpt_path = self.models_dir / f"{task}_model_{seed}.pt" + scaler_path = self.models_dir / f"{task}_scalar_{seed}.pt" + if not ckpt_path.exists() or not scaler_path.exists(): + raise FileNotFoundError(f"Missing model/scaler for task={task} seed={seed}") + + ckpt = torch.load(ckpt_path, map_location=self.device) + state_dict = ckpt["model"] + train_args = ckpt.get("args", {}) + + scaler = _load_scaler_compat(scaler_path) + task_names = list(getattr(scaler, "targets", [])) + if not task_names: + raise RuntimeError(f"No targets found in scaler: {scaler_path}") + + if "fid_embed.weight" in state_dict: + num_fids = state_dict["fid_embed.weight"].shape[0] + else: + num_fids = 1 + + meta = { + "train_args": train_args, + "task_names": task_names, + "num_fids": num_fids, + } + self._cache[key] = (None, scaler, meta) + return self._cache[key] + + def _build_if_needed(self, task: str, seed: int, in_dim_node: int, in_dim_edge: int): + task = task.strip().lower() + key = (task, seed) + model, scaler, meta = self._cache[key] + if model is not None: + return model, scaler, meta + + train_args = meta["train_args"] + task_names = 
meta["task_names"] + num_fids = meta["num_fids"] + + model = build_model( + in_dim_node=in_dim_node, + in_dim_edge=in_dim_edge, + task_names=task_names, + num_fids=num_fids, + gnn_type=train_args.get("gnn_type", "gine"), + gnn_emb_dim=train_args.get("gnn_emb_dim", 256), + gnn_layers=train_args.get("gnn_layers", 5), + gnn_norm=train_args.get("gnn_norm", "batch"), + gnn_readout=train_args.get("gnn_readout", "mean"), + gnn_act=train_args.get("gnn_act", "relu"), + gnn_dropout=train_args.get("gnn_dropout", 0.0), + gnn_residual=train_args.get("gnn_residual", True), + fid_emb_dim=train_args.get("fid_emb_dim", 64), + use_film=train_args.get("use_film", True), + use_task_embed=train_args.get("use_task_embed", True), + task_emb_dim=train_args.get("task_emb_dim", 32), + head_hidden=train_args.get("head_hidden", 512), + head_depth=train_args.get("head_depth", 2), + head_act=train_args.get("head_act", "relu"), + head_dropout=train_args.get("head_dropout", 0.0), + heteroscedastic=train_args.get("heteroscedastic", False), + fid_emb_l2=0.0, + task_emb_l2=0.0, + use_task_uncertainty=train_args.get("task_uncertainty", False), + ).to(self.device) + + ckpt_path = self.models_dir / f"{task}_model_{seed}.pt" + ckpt = torch.load(ckpt_path, map_location=self.device) + model.load_state_dict(ckpt["model"], strict=True) + model.eval() + + self._cache[key] = (model, scaler, meta) + return model, scaler, meta + + def predict_mean_std(self, smiles: str, prop_key: str, task: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]: + task = task.strip().lower() + prop_key = prop_key.lower() + + seeds = self.available_seeds(task) + if not seeds: + return None, None, {} + + self._load_one_meta(task, seeds[0]) + _, scaler0, meta0 = self._cache[(task, seeds[0])] + targets = list(meta0["task_names"]) # already lower() + if prop_key not in targets: + return None, None, {} + + t_idx = targets.index(prop_key) + T = len(targets) + + try: + g = _make_one_graph(smiles, T=T, fid_idx=0) + except Exception: + return None, None, {} + + in_dim_node = g.x.shape[1] + in_dim_edge = g.edge_attr.shape[1] + + per_seed: Dict[int, float] = {} + with torch.no_grad(): + for seed in seeds: + self._load_one_meta(task, seed) + model, scaler, meta = self._build_if_needed(task, seed, in_dim_node, in_dim_edge) + + batch = to_device(g, self.device) + out = model(batch) + pred_n = out["pred"] # [1, T] + pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1) + val = float(pred[t_idx]) + + # unit correction + val *= POST_SCALE.get(prop_key, 1.0) + + per_seed[seed] = val + + vals = np.array(list(per_seed.values()), dtype=float) + mean = float(vals.mean()) + std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0 + return mean, std, per_seed diff --git a/src/predictor_router.py b/src/predictor_router.py new file mode 100644 index 0000000000000000000000000000000000000000..8b66ec258570d8f3133cad5e155f6d7921c495a7 --- /dev/null +++ b/src/predictor_router.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Dict, Optional, Tuple + +from src.predictor import SingleTaskEnsemblePredictor +from src.predictor_multitask import MultiTaskEnsemblePredictor + + +class RouterPredictor: + """ + Routes each property to either: + - single-task ensemble (models/single_models) + - multitask ensemble (models/multitask_models/{task}_*) + based on models/best_model_map.json + """ + + def __init__( + self, + map_path: str = "models/best_model_map.json", + single_dir: str = 
"models/single_models", + multitask_dir: str = "models/multitask_models", + device: str = "cpu", + ): + self.map_path = Path(map_path) + self.map: Dict[str, dict] = json.load(open(self.map_path)) + self.single = SingleTaskEnsemblePredictor(models_dir=single_dir, device=device) + self.multi = MultiTaskEnsemblePredictor(models_dir=multitask_dir, device=device) + + def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], dict, str]: + prop = prop.lower() + cfg = self.map.get(prop, {"family": "single"}) + + fam = cfg.get("family", "single").lower() + if fam == "multitask": + task = str(cfg.get("task", "all")).lower() + mean, std, per_seed = self.multi.predict_mean_std(smiles, prop_key=prop, task=task) + label = f"multitask:{task}" + return mean, std, per_seed, label + + # default: single + mean, std, per_seed = self.single.predict_mean_std(smiles, prop) + label = "single" + return mean, std, per_seed, label diff --git a/src/rnn_smiles/__init__.py b/src/rnn_smiles/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3979e166d0058335329a9ceb6a0cf035aac32b --- /dev/null +++ b/src/rnn_smiles/__init__.py @@ -0,0 +1,22 @@ +"""RNN-based SMILES generation helpers for Streamlit pages.""" + +from .generator import ( + canonicalize_smiles, + filter_novel_smiles, + generate_smiles, + load_existing_smiles_set, + load_rnn_model, +) +from .rnn import MultiGRU, RNN +from .vocabulary import Vocabulary + +__all__ = [ + "canonicalize_smiles", + "filter_novel_smiles", + "generate_smiles", + "load_existing_smiles_set", + "load_rnn_model", + "MultiGRU", + "RNN", + "Vocabulary", +] diff --git a/src/rnn_smiles/__pycache__/__init__.cpython-310.pyc b/src/rnn_smiles/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31953857b27eca170168d047d6761d74c20eefd7 Binary files /dev/null and b/src/rnn_smiles/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/rnn_smiles/__pycache__/__init__.cpython-313.pyc b/src/rnn_smiles/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75cdee29ef22589fcc66e6eb069f43c742fb8bf5 Binary files /dev/null and b/src/rnn_smiles/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/rnn_smiles/__pycache__/generator.cpython-310.pyc b/src/rnn_smiles/__pycache__/generator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f54493a42a28e8f675ba55f18cfba357479f90f5 Binary files /dev/null and b/src/rnn_smiles/__pycache__/generator.cpython-310.pyc differ diff --git a/src/rnn_smiles/__pycache__/generator.cpython-313.pyc b/src/rnn_smiles/__pycache__/generator.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be4538d70b4c9ae6b4da564e505767ffe0c1ff2a Binary files /dev/null and b/src/rnn_smiles/__pycache__/generator.cpython-313.pyc differ diff --git a/src/rnn_smiles/__pycache__/rnn.cpython-310.pyc b/src/rnn_smiles/__pycache__/rnn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22b99587e44351c02be84c2fad29be064213b866 Binary files /dev/null and b/src/rnn_smiles/__pycache__/rnn.cpython-310.pyc differ diff --git a/src/rnn_smiles/__pycache__/rnn.cpython-313.pyc b/src/rnn_smiles/__pycache__/rnn.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7203b56c8e829433ee476c21cde0c48752dde4d7 Binary files /dev/null and b/src/rnn_smiles/__pycache__/rnn.cpython-313.pyc differ diff --git 
a/src/rnn_smiles/__pycache__/utils.cpython-313.pyc b/src/rnn_smiles/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42e2fa52ef5ab24656c2d03342a0cdf484222139 Binary files /dev/null and b/src/rnn_smiles/__pycache__/utils.cpython-313.pyc differ diff --git a/src/rnn_smiles/__pycache__/vocabulary.cpython-310.pyc b/src/rnn_smiles/__pycache__/vocabulary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8ccbc4b577a1a7a5b4ab8a7b10d11da849610c0 Binary files /dev/null and b/src/rnn_smiles/__pycache__/vocabulary.cpython-310.pyc differ diff --git a/src/rnn_smiles/__pycache__/vocabulary.cpython-313.pyc b/src/rnn_smiles/__pycache__/vocabulary.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bec8e49cafb9cc9786505f2f023a8457f12559bd Binary files /dev/null and b/src/rnn_smiles/__pycache__/vocabulary.cpython-313.pyc differ diff --git a/src/rnn_smiles/generator.py b/src/rnn_smiles/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..7409d5c18e054233355a0a2dc7f68effbb58819b --- /dev/null +++ b/src/rnn_smiles/generator.py @@ -0,0 +1,226 @@ +"""Streamlit integration helpers for RNN SMILES generation.""" + +from __future__ import annotations + +import pickle +from pathlib import Path +from typing import Iterable, Sequence + +import pandas as pd +import streamlit as st +import torch +from rdkit import Chem, RDLogger + +from .rnn import RNN +from .vocabulary import Vocabulary + +RDLogger.DisableLog("rdApp.*") + + +def canonicalize_smiles(smiles: str) -> str | None: + s = (smiles or "").strip() + if not s: + return None + mol = Chem.MolFromSmiles(s) + if mol is None: + return None + return Chem.MolToSmiles(mol, canonical=True) + + +def _find_smiles_column(path: Path) -> str | None: + try: + header = pd.read_csv(path, nrows=0) + except Exception: + return None + + columns = [str(c) for c in header.columns] + norm_to_col = {str(c).strip().lower(): c for c in columns} + + for candidate in ["smiles", "canonical_smiles", "canonical smiles", "smile", "smi"]: + if candidate in norm_to_col: + return norm_to_col[candidate] + + for norm, col in norm_to_col.items(): + if "smiles" in norm: + return col + + return None + + +def _load_checkpoint(path: Path, device: torch.device) -> dict: + # Prefer secure mode; allow trusted local fallback for legacy checkpoints. + path = path.expanduser().resolve() + try: + with path.open("r", encoding="utf-8") as fh: + first = fh.readline().strip() + if first == "version https://git-lfs.github.com/spec/v1": + raise RuntimeError( + "Checkpoint is a Git LFS pointer, not model weights. " + "Replace it with the real .ckpt file before running generation." + ) + except UnicodeDecodeError: + # Binary checkpoint (expected). + pass + + trusted_root = (Path(__file__).resolve().parents[2] / "models").resolve() + try: + state = torch.load(path, map_location=device, weights_only=True) + except TypeError: + state = torch.load(path, map_location=device) + except (pickle.UnpicklingError, RuntimeError) as exc: + weights_only_failure = isinstance(exc, pickle.UnpicklingError) or ( + "Weights only load failed" in str(exc) + ) + if not weights_only_failure: + raise + if not path.is_relative_to(trusted_root): + raise RuntimeError( + "Refusing unsafe checkpoint load outside the local models directory." 
+ ) from exc + state = torch.load(path, map_location=device, weights_only=False) + if isinstance(state, dict) and isinstance(state.get("state_dict"), dict): + state = state["state_dict"] + if not isinstance(state, dict): + raise RuntimeError(f"Checkpoint does not contain a state dict: {path}") + return state + + +@st.cache_resource(show_spinner=False) +def load_rnn_model(ckpt_path: str | Path, voc_path: str | Path) -> tuple[RNN, Vocabulary]: + ckpt_path = Path(ckpt_path).expanduser().resolve() + voc_path = Path(voc_path).expanduser().resolve() + + if not ckpt_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + if not voc_path.exists(): + raise FileNotFoundError(f"Vocabulary not found: {voc_path}") + + voc = Vocabulary(init_from_file=str(voc_path)) + model = RNN(voc) + model_device = next(model.rnn.parameters()).device + state = _load_checkpoint(ckpt_path, model_device) + + ckpt_vocab_size = None + if "embedding.weight" in state: + ckpt_vocab_size = int(state["embedding.weight"].shape[0]) + if ckpt_vocab_size is not None and ckpt_vocab_size != voc.vocab_size: + raise RuntimeError( + f"Vocabulary size mismatch: voc has {voc.vocab_size} tokens, " + f"checkpoint expects {ckpt_vocab_size}. " + "Use the matching vocab file for this checkpoint." + ) + + model.rnn.load_state_dict(state) + model.rnn.eval() + return model, voc + + +def _sample_with_temperature( + model: RNN, voc: Vocabulary, batch_size: int, max_length: int, temperature: float +) -> torch.Tensor: + temp = max(float(temperature), 1e-6) + device = next(model.rnn.parameters()).device + start_token = torch.full((batch_size,), voc.vocab["GO"], dtype=torch.long, device=device) + h = model.rnn.init_h(batch_size) + x = start_token + + sequences: list[torch.Tensor] = [] + finished = torch.zeros(batch_size, dtype=torch.bool, device=device) + + for _ in range(max_length): + logits, h = model.rnn(x, h) + logits = logits / temp + prob = torch.softmax(logits, dim=1) + x = torch.multinomial(prob, 1).view(-1) + sequences.append(x.view(-1, 1)) + finished = finished | (x == voc.vocab["EOS"]) + if torch.all(finished): + break + + if not sequences: + return torch.empty((batch_size, 0), dtype=torch.long, device=device) + return torch.cat(sequences, dim=1) + + +def generate_smiles( + model: RNN, + voc: Vocabulary, + n: int, + max_length: int, + temperature: float = 1.0, +) -> list[str]: + if n <= 0: + return [] + max_length = max(int(max_length), 1) + + with torch.no_grad(): + if abs(float(temperature) - 1.0) < 1e-8: + seqs, _, _ = model.sample(int(n), max_length=max_length) + else: + seqs = _sample_with_temperature( + model, + voc, + int(n), + max_length, + float(temperature), + ) + arr = seqs.detach().cpu().numpy() + + output: list[str] = [] + for seq in arr: + output.append(voc.decode(seq)) + return output + + +def filter_novel_smiles(smiles: Iterable[str], existing: set[str]) -> list[str]: + novel: list[str] = [] + seen: set[str] = set() + for smi in smiles: + canonical = canonicalize_smiles(smi) + if canonical is None: + continue + if canonical in seen: + continue + seen.add(canonical) + if canonical in existing: + continue + novel.append(canonical) + return novel + + +@st.cache_resource(show_spinner=False) +def load_existing_smiles_set(csv_paths: Sequence[str | Path], chunksize: int = 200_000) -> set[str]: + existing: set[str] = set() + for p in csv_paths: + path = Path(p) + if not path.exists(): + continue + col = _find_smiles_column(path) + if col is None: + # Skip malformed files or tables without a recognizable 
SMILES column. + continue + try: + reader = pd.read_csv( + path, + usecols=[col], + chunksize=int(chunksize), + on_bad_lines="skip", + ) + for chunk in reader: + for smiles in chunk[col].astype(str): + canonical = canonicalize_smiles(smiles) + if canonical: + existing.add(canonical) + except Exception: + # Skip files that fail to parse to keep generation usable. + continue + return existing + + +__all__ = [ + "canonicalize_smiles", + "load_rnn_model", + "generate_smiles", + "filter_novel_smiles", + "load_existing_smiles_set", +] diff --git a/src/rnn_smiles/rnn.py b/src/rnn_smiles/rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5056b37933a23c7b388d99f75ae161774badf2 --- /dev/null +++ b/src/rnn_smiles/rnn.py @@ -0,0 +1,89 @@ +"""Core GRU model used for polymer SMILES generation.""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MultiGRU(nn.Module): + def __init__(self, vocab_size: int): + super().__init__() + self.embedding = nn.Embedding(vocab_size, 128) + self.gru_1 = nn.GRUCell(128, 512) + self.gru_2 = nn.GRUCell(512, 512) + self.gru_3 = nn.GRUCell(512, 512) + self.linear = nn.Linear(512, vocab_size) + + def forward(self, x: torch.Tensor, h: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + x = self.embedding(x) + h_out = torch.zeros_like(h) + x = h_out[0] = self.gru_1(x, h[0]) + x = h_out[1] = self.gru_2(x, h[1]) + x = h_out[2] = self.gru_3(x, h[2]) + x = self.linear(x) + return x, h_out + + def init_h(self, batch_size: int) -> torch.Tensor: + device = next(self.parameters()).device + return torch.zeros(3, batch_size, 512, device=device) + + +def nll_loss(log_probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + # Gather selected token log-probability for each sample in batch. 
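+    # Shapes: log_probs [B, V], targets [B] -> output [B]; entry i is
+    # log_probs[i, targets[i]]. Despite the name, the selected log-probability
+    # is returned as-is (not negated); callers in `likelihood`/`sample` add it
+    # directly into their running log_probs.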
+ return log_probs.gather(1, targets.contiguous().view(-1, 1)).squeeze(1) + + +class RNN: + def __init__(self, voc): + self.rnn = MultiGRU(voc.vocab_size) + if torch.cuda.is_available(): + self.rnn.cuda() + self.voc = voc + + def likelihood(self, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch_size, seq_length = target.size() + device = target.device + start_token = torch.full((batch_size, 1), self.voc.vocab["GO"], dtype=torch.long, device=device) + x = torch.cat((start_token, target[:, :-1]), 1) + h = self.rnn.init_h(batch_size) + + log_probs = torch.zeros(batch_size, device=device) + entropy = torch.zeros(batch_size, device=device) + for step in range(seq_length): + logits, h = self.rnn(x[:, step], h) + log_prob = F.log_softmax(logits, dim=1) + prob = F.softmax(logits, dim=1) + log_probs += nll_loss(log_prob, target[:, step]) + entropy += -torch.sum((log_prob * prob), 1) + return log_probs, entropy + + def sample(self, batch_size: int, max_length: int = 140) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = next(self.rnn.parameters()).device + start_token = torch.full((batch_size,), self.voc.vocab["GO"], dtype=torch.long, device=device) + h = self.rnn.init_h(batch_size) + x = start_token + + sequences: list[torch.Tensor] = [] + log_probs = torch.zeros(batch_size, device=device) + finished = torch.zeros(batch_size, dtype=torch.bool, device=device) + entropy = torch.zeros(batch_size, device=device) + + for _ in range(max_length): + logits, h = self.rnn(x, h) + prob = F.softmax(logits, dim=1) + log_prob = F.log_softmax(logits, dim=1) + x = torch.multinomial(prob, 1).view(-1) + sequences.append(x.view(-1, 1)) + log_probs += nll_loss(log_prob, x) + entropy += -torch.sum((log_prob * prob), 1) + finished = finished | (x == self.voc.vocab["EOS"]) + if torch.all(finished): + break + + if sequences: + stacked = torch.cat(sequences, 1) + else: + stacked = torch.empty((batch_size, 0), dtype=torch.long, device=device) + return stacked, log_probs, entropy diff --git a/src/rnn_smiles/utils.py b/src/rnn_smiles/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd72739d8f75e62e5d677286e57434576696b5e --- /dev/null +++ b/src/rnn_smiles/utils.py @@ -0,0 +1,15 @@ +"""Utility helpers used by the legacy-style RNN generator.""" + +from __future__ import annotations + +import numpy as np +import torch + + +def variable(tensor: torch.Tensor | np.ndarray) -> torch.Tensor: + """Return a tensor on GPU when available.""" + if isinstance(tensor, np.ndarray): + tensor = torch.from_numpy(tensor) + if torch.cuda.is_available(): + return tensor.cuda() + return tensor diff --git a/src/rnn_smiles/vocabulary.py b/src/rnn_smiles/vocabulary.py new file mode 100644 index 0000000000000000000000000000000000000000..e57a7e4c7080b4c0e31019e1b46fa17934fee4a2 --- /dev/null +++ b/src/rnn_smiles/vocabulary.py @@ -0,0 +1,69 @@ +"""Token vocabulary used by the SMILES RNN.""" + +from __future__ import annotations + +import re + +import numpy as np + + +class Vocabulary: + def __init__(self, init_from_file: str | None = None, max_length: int | None = None): + self.special_tokens = ["EOS", "GO"] + self.additional_chars: set[str] = set() + self.chars = self.special_tokens + self.vocab_size = len(self.chars) + self.vocab = dict(zip(self.chars, range(len(self.chars)))) + self.reversed_vocab = {v: k for k, v in self.vocab.items()} + self.max_length = max_length + if init_from_file: + self.init_from_file(init_from_file) + + def encode(self, char_list: list[str]) -> np.ndarray: + 
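+        """Map a tokenized SMILES (list of tokens) to an array of vocabulary indices."""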
smiles_matrix = np.zeros(len(char_list), dtype=np.float32) + for i, char in enumerate(char_list): + smiles_matrix[i] = self.vocab[char] + return smiles_matrix + + def decode(self, matrix: np.ndarray) -> str: + chars: list[str] = [] + eos_id = self.vocab["EOS"] + for i in matrix: + if int(i) == eos_id: + break + chars.append(self.reversed_vocab[int(i)]) + return "".join(chars) + + def tokenize(self, smiles: str) -> list[str]: + regex = r"(\[[^\[\]]{1,6}\])" + char_list = re.split(regex, smiles) + tokenized: list[str] = [] + for char in char_list: + if not char: + continue + if char.startswith("["): + tokenized.append(char) + else: + tokenized.extend(list(char)) + tokenized.append("EOS") + return tokenized + + def add_characters(self, chars: list[str]) -> None: + for char in chars: + self.additional_chars.add(char) + char_list = sorted(list(self.additional_chars)) + self.chars = char_list + self.special_tokens + self.vocab_size = len(self.chars) + self.vocab = dict(zip(self.chars, range(len(self.chars)))) + self.reversed_vocab = {v: k for k, v in self.vocab.items()} + + def init_from_file(self, file_path: str) -> None: + with open(file_path, "r", encoding="utf-8") as f: + chars = f.read().split() + self.add_characters(chars) + + def __len__(self) -> int: + return len(self.chars) + + def __str__(self) -> str: + return f"Vocabulary containing {len(self)} tokens: {self.chars}" diff --git a/src/sascorer.py b/src/sascorer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba618acaffff2a535ac109d766c559acab301c1c --- /dev/null +++ b/src/sascorer.py @@ -0,0 +1,192 @@ +# +# calculation of synthetic accessibility score as described in: +# +# Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions +# Peter Ertl and Ansgar Schuffenhauer +# Journal of Cheminformatics 1:8 (2009) +# http://www.jcheminf.com/content/1/1/8 +# +# several small modifications to the original paper are included +# particularly slightly different formula for marocyclic penalty +# and taking into account also molecule symmetry (fingerprint density) +# +# for a set of 10k diverse molecules the agreement between the original method +# as implemented in PipelinePilot and this implementation is r2 = 0.97 +# +# peter ertl & greg landrum, september 2013 +# + +from rdkit import Chem +from rdkit.Chem import rdFingerprintGenerator, rdMolDescriptors + +import math +import pickle + +import os.path as op + +_fscores = None +mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2) + + +def readFragmentScores(name="fpscores.pkl.gz"): + import gzip + global _fscores + # generate the full path filename: + if name == "fpscores.pkl.gz": + name = op.join(op.dirname(__file__), name) + data = pickle.load(gzip.open(name)) + outDict = {} + for i in data: + for j in range(1, len(i)): + outDict[i[j]] = float(i[0]) + _fscores = outDict + + +def numBridgeheadsAndSpiro(mol, ri=None): + nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) + nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) + return nBridgehead, nSpiro + + +def calculateScore(m): + + if not m.GetNumAtoms(): + return None + + if _fscores is None: + readFragmentScores() + + # fragment score + sfp = mfpgen.GetSparseCountFingerprint(m) + + score1 = 0. 
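+    # Average the precomputed fragment contributions over all counted Morgan
+    # fingerprint bits; fragments missing from the score table get the -4 floor.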
+ nf = 0 + nze = sfp.GetNonzeroElements() + for id, count in nze.items(): + nf += count + score1 += _fscores.get(id, -4) * count + + score1 /= nf + + # features score + nAtoms = m.GetNumAtoms() + nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) + ri = m.GetRingInfo() + nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) + nMacrocycles = 0 + for x in ri.AtomRings(): + if len(x) > 8: + nMacrocycles += 1 + + sizePenalty = nAtoms**1.005 - nAtoms + stereoPenalty = math.log10(nChiralCenters + 1) + spiroPenalty = math.log10(nSpiro + 1) + bridgePenalty = math.log10(nBridgeheads + 1) + macrocyclePenalty = 0. + # --------------------------------------- + # This differs from the paper, which defines: + # macrocyclePenalty = math.log10(nMacrocycles+1) + # This form generates better results when 2 or more macrocycles are present + if nMacrocycles > 0: + macrocyclePenalty = math.log10(2) + + score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty + + # correction for the fingerprint density + # not in the original publication, added in version 1.1 + # to make highly symmetrical molecules easier to synthetise + score3 = 0. + numBits = len(nze) + if nAtoms > numBits: + score3 = math.log(float(nAtoms) / numBits) * .5 + + sascore = score1 + score2 + score3 + + # need to transform "raw" value into scale between 1 and 10 + min = -4.0 + max = 2.5 + sascore = 11. - (sascore - min + 1) / (max - min) * 9. + + # smooth the 10-end + if sascore > 8.: + sascore = 8. + math.log(sascore + 1. - 9.) + if sascore > 10.: + sascore = 10.0 + elif sascore < 1.: + sascore = 1.0 + + return sascore + + +def processMols(mols): + print('smiles\tName\tsa_score') + for i, m in enumerate(mols): + if m is None: + continue + + s = calculateScore(m) + + smiles = Chem.MolToSmiles(m) + if s is None: + print(f"{smiles}\t{m.GetProp('_Name')}\t{s}") + else: + print(f"{smiles}\t{m.GetProp('_Name')}\t{s:3f}") + + +if __name__ == '__main__': + import sys + import time + + t1 = time.time() + if len(sys.argv) == 2: + readFragmentScores() + else: + readFragmentScores(sys.argv[2]) + t2 = time.time() + + molFile = sys.argv[1] + if molFile.endswith("smi"): + suppl = Chem.SmilesMolSupplier(molFile) + elif molFile.endswith("sdf"): + suppl = Chem.SDMolSupplier(molFile) + else: + print(f"Unrecognized file extension for {molFile}") + sys.exit() + + t3 = time.time() + processMols(suppl) + t4 = time.time() + + print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), + file=sys.stderr) + +# +# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# * Neither the name of Novartis Institutes for BioMedical Research Inc. +# nor the names of its contributors may be used to endorse or promote +# products derived from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# diff --git a/src/streamlit_app.py b/src/streamlit_app.py new file mode 100644 index 0000000000000000000000000000000000000000..99d0b84662681e7d21a08fcce44908344fa86f80 --- /dev/null +++ b/src/streamlit_app.py @@ -0,0 +1,40 @@ +import altair as alt +import numpy as np +import pandas as pd +import streamlit as st + +""" +# Welcome to Streamlit! + +Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:. +If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community +forums](https://discuss.streamlit.io). + +In the meantime, below is an example of what you can do with just a few lines of code: +""" + +num_points = st.slider("Number of points in spiral", 1, 10000, 1100) +num_turns = st.slider("Number of turns in spiral", 1, 300, 31) + +indices = np.linspace(0, 1, num_points) +theta = 2 * np.pi * num_turns * indices +radius = indices + +x = radius * np.cos(theta) +y = radius * np.sin(theta) + +df = pd.DataFrame({ + "x": x, + "y": y, + "idx": indices, + "rand": np.random.randn(num_points), +}) + +st.altair_chart(alt.Chart(df, height=700, width=700) + .mark_point(filled=True) + .encode( + x=alt.X("x", axis=None), + y=alt.Y("y", axis=None), + color=alt.Color("idx", legend=None, scale=alt.Scale()), + size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])), + )) \ No newline at end of file diff --git a/src/ui_style.py b/src/ui_style.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d7a0733b394c9c4b61a58f76186e97db2b5602 --- /dev/null +++ b/src/ui_style.py @@ -0,0 +1,1003 @@ +import base64 +import html +import os +from pathlib import Path +from urllib import request + +import streamlit as st + + +def _icon_data_uri(filename: str) -> str: + icon_path = Path(__file__).resolve().parent.parent / "icons" / filename + if not icon_path.exists(): + return "" + try: + encoded = base64.b64encode(icon_path.read_bytes()).decode("ascii") + except Exception: + return "" + return f"data:image/png;base64,{encoded}" + + +def _config_value(name: str, default: str = "") -> str: + try: + if name in st.secrets: + return str(st.secrets[name]).strip() + except Exception: + pass + return str(os.getenv(name, default)).strip() + + +def _build_sidebar_icon_css() -> str: + fallback = { + 1: "🏠", + 2: "🔎", + 3: "📦", + 4: "🧬", + 5: "⚙️", + 6: "🧠", + 7: "✨", + 8: "💬", + 9: "📚", + } + icon_name = { + 1: "home1.png", + 2: "probe1.png", + 3: "batch1.png", + 4: "molecule1.png", + 5: "manual1.png", + 6: "ai1.png", + 7: "rnn1.png", + 8: "literature.png", + 9: "feedback.png", + } + rules = [ + '[data-testid="stSidebarNav"] ul li a { position: relative; padding-left: 3.25rem !important; }', + '[data-testid="stSidebarNav"] ul li a::before { content: ""; position: 
absolute; left: 12px; top: 50%; transform: translateY(-50%); width: 32px; height: 32px; background-size: contain; background-repeat: no-repeat; background-position: center; }', + ] + for idx in range(1, 10): + uri = _icon_data_uri(icon_name[idx]) + if uri: + rules.append( + '[data-testid="stSidebarNav"] ul li:nth-of-type(%d) a::before { content: ""; background-image: url("%s"); }' + % (idx, uri) + ) + else: + emoji = fallback[idx] + rules.append( + '[data-testid="stSidebarNav"] ul li:nth-of-type(%d) a::before { content: "%s"; background-image: none; width: auto; height: auto; font-size: 1.4rem; }' + % (idx, emoji) + ) + return "\n".join(rules) + + +def _log_visit_once_per_session() -> None: + if st.session_state.get("_visit_logged"): + return + webhook_url = _config_value("FEEDBACK_WEBHOOK_URL", "") + webhook_token = _config_value("FEEDBACK_WEBHOOK_TOKEN", "") + if not webhook_url: + return + endpoint = webhook_url + sep = "&" if "?" in webhook_url else "?" + endpoint = f"{webhook_url}{sep}event=visit" + if webhook_token: + endpoint = f"{endpoint}&token={webhook_token}" + try: + with request.urlopen(endpoint, timeout=3): + pass + except Exception: + pass + st.session_state["_visit_logged"] = True + + +def render_page_header(title: str, subtitle: str = "", badge: str = "") -> None: + title_html = html.escape(title) + subtitle_html = html.escape(subtitle) if subtitle else "" + badge_html = html.escape(badge) if badge else "" + + st.markdown( + f""" +
+        <div>
+            {"<span>" + badge_html + "</span>" if badge_html else ""}
+            <h2>{title_html}</h2>
+            {"<p>" + subtitle_html + "</p>" if subtitle_html else ""}
+        </div>
+""", + unsafe_allow_html=True, + ) + + +def apply_global_style() -> None: + _log_visit_once_per_session() + icon_css = _build_sidebar_icon_css() + css = """ + + """ + st.markdown(css.replace("__ICON_CSS__", icon_css), unsafe_allow_html=True) diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..232361e87e74406c17575e4a23638cb48157b402 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,338 @@ +# utils.py +from __future__ import annotations + +from typing import Dict, List, Optional, Sequence, Literal + +import math +import numpy as np +import torch +import torch.nn as nn + +# Re-exported conveniences from data_builder +from src.data_builder import TargetScaler, grouped_split_by_smiles # noqa: F401 + + +# --------------------------------------------------------- +# Seeding and device helpers +# --------------------------------------------------------- + +def seed_everything(seed: int) -> None: + """Deterministically seed Python, NumPy, and PyTorch (CPU/CUDA).""" + import random + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def to_device(batch, device: torch.device): + """Move a PyG Batch or simple dict of tensors to device.""" + if hasattr(batch, "to"): + return batch.to(device) + if isinstance(batch, dict): + return {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()} + return batch + + +# --------------------------------------------------------- +# Masked metrics (canonical) +# --------------------------------------------------------- + +def _safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor: + den = torch.clamp(den, min=1e-12) + return num / den + + +def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, + reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor: + """ + pred/target: [B, T]; mask: [B, T] bool + """ + pred, target = pred.float(), target.float() + mask = mask.bool() + se = ((pred - target) ** 2) * mask + if reduction == "sum": + return se.sum() + return _safe_div(se.sum(), mask.sum().float()) + + +def masked_mae(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, + reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor: + ae = (pred - target).abs() * mask + if reduction == "sum": + return ae.sum() + return _safe_div(ae.sum(), mask.sum().float()) + + +def masked_rmse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + return torch.sqrt(masked_mse(pred, target, mask, reduction="mean")) + + +def masked_r2(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """ + Masked coefficient of determination across all elements jointly. 
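+
+    Computed as R^2 = 1 - SSE / SST, where both sums run only over entries with
+    mask == True and the target mean used in SST is the masked mean.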
+ """ + pred, target = pred.float(), target.float() + mask = mask.bool() + count = mask.sum().float().clamp(min=1.0) + mean = _safe_div((target * mask).sum(), count) + sst = (((target - mean) ** 2) * mask).sum() + sse = (((target - pred) ** 2) * mask).sum() + return 1.0 - _safe_div(sse, sst.clamp(min=1e-12)) + + +def masked_metrics_overall(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> Dict[str, float]: + return { + "rmse": float(masked_rmse(pred, target, mask).detach().cpu()), + "mae": float(masked_mae(pred, target, mask).detach().cpu()), + "r2": float(masked_r2(pred, target, mask).detach().cpu()), + } + + +def masked_metrics_per_task( + pred: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + task_names: Sequence[str], +) -> Dict[str, Dict[str, float]]: + """ + Per-task metrics using the same masked formulations. + """ + out: Dict[str, Dict[str, float]] = {} + for t, name in enumerate(task_names): + m = mask[:, t] + if m.any(): + rmse = float(masked_rmse(pred[:, t:t+1], target[:, t:t+1], m.unsqueeze(1)).detach().cpu()) + mae = float(masked_mae(pred[:, t:t+1], target[:, t:t+1], m.unsqueeze(1)).detach().cpu()) + r2 = float(masked_r2(pred[:, t:t+1], target[:, t:t+1], m.unsqueeze(1)).detach().cpu()) + else: + rmse = mae = r2 = float("nan") + out[name] = {"rmse": rmse, "mae": mae, "r2": r2} + return out + + +def masked_metrics_by_fidelity( + pred: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + fid_idx: torch.Tensor, + fid_names: Sequence[str], + task_names: Sequence[str], # kept for API parity; not used in overall-by-fid +) -> Dict[str, Dict[str, float]]: + """ + Overall metrics per fidelity (aggregated across tasks). + """ + out: Dict[str, Dict[str, float]] = {} + fid_idx = fid_idx.view(-1).long() + for i, fname in enumerate(fid_names): + sel = (fid_idx == i) + if sel.any(): + p = pred[sel] + y = target[sel] + m = mask[sel] + out[fname] = masked_metrics_overall(p, y, m) + else: + out[fname] = {"rmse": float("nan"), "mae": float("nan"), "r2": float("nan")} + return out + + +# --------------------------------------------------------- +# Multitask, multi-fidelity loss (canonical) +# --------------------------------------------------------- + +def gaussian_nll(mu: torch.Tensor, logvar: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Element-wise Gaussian NLL (no reduction). + Shapes: mu, logvar, target -> [B, T] (or broadcastable). + """ + logvar = torch.as_tensor(logvar, device=mu.device, dtype=mu.dtype) + logvar = logvar.clamp(min=-20.0, max=20.0) # numerical guard + var = torch.exp(logvar) + err2_over_var = (target - mu) ** 2 / var + nll = 0.5 * (err2_over_var + logvar + math.log(2.0 * math.pi)) # [B, T] + return nll + + +def loss_multitask_fidelity( + *, + pred: torch.Tensor, # [B, T] (or means if heteroscedastic) + target: torch.Tensor, # [B, T] + mask: torch.Tensor, # [B, T] bool + fid_idx: torch.Tensor, # [B] long (per-row fidelity index) + fid_loss_w: Sequence[float] | torch.Tensor | None, # [F] weights per fidelity + task_weights: Optional[Sequence[float] | torch.Tensor] = None, # [T] + hetero_logvar: Optional[torch.Tensor] = None, # [B, T] if heteroscedastic head + reduction: Literal["mean", "sum"] = "mean", + task_log_sigma2: Optional[torch.Tensor] = None, # [T] learned homoscedastic uncertainty + balanced: bool = True, +) -> torch.Tensor: + """ + Multi-task, multi-fidelity loss with *balanced per-task reduction* by default. + + - If `hetero_logvar` is given: uses Gaussian NLL per element. 
+ - Applies per-fidelity weights via `fid_idx`. + - Balanced reduction: compute mean loss per task first, then average across tasks + (optionally weight by `task_weights` or learned uncertainty `task_log_sigma2`). + - If `balanced=False`, uses legacy global reduction. + """ + B, T = pred.shape + pred = pred.float() + target = target.float() + mask = mask.bool() + fid_idx = fid_idx.view(-1).long() + + # Task weights (optional) + if task_weights is None: + tw = pred.new_ones(T) # [T] + else: + tw = torch.as_tensor(task_weights, dtype=pred.dtype, device=pred.device) + assert tw.numel() == T, f"task_weights len {tw.numel()} != T {T}" + s = tw.sum().clamp(min=1e-12) + tw = tw * (T / s) # normalize to sum=T for stable scale + + # Fidelity weights + if fid_loss_w is None: + fw = pred.new_ones(int(fid_idx.max().item()) + 1) + else: + fw = torch.as_tensor(fid_loss_w, dtype=pred.dtype, device=pred.device) + w_fid = fw[fid_idx].unsqueeze(1).expand(-1, T) # [B, T] + + # Elementwise loss + if hetero_logvar is not None: + elem_loss = gaussian_nll(pred, hetero_logvar.float(), target) # [B, T] + else: + elem_loss = (pred - target) ** 2 # [B, T] + + if not balanced: + # Legacy global reduction (label-count biased) + w_task = tw.view(1, T).expand(B, -1) + weighted = elem_loss * mask * w_task * w_fid + if reduction == "sum": + return weighted.sum() + denom = (mask * w_task * w_fid).sum().float().clamp(min=1e-12) + return weighted.sum() / denom + + # -------- Balanced per-task reduction -------- + # First compute a per-task average (exclude tw here) + num = (elem_loss * mask * w_fid).sum(dim=0) # [T] + den = (mask * w_fid).sum(dim=0).float().clamp(min=1e-12) # [T] + per_task_loss = num / den # [T] + + # Optional manual task weights AFTER per-task averaging + if task_weights is not None: + per_task_loss = per_task_loss * tw + + # Optional homoscedastic task-uncertainty weighting (Kendall & Gal) + if task_log_sigma2 is not None: + assert task_log_sigma2.numel() == T, f"task_log_sigma2 must be [T], got {task_log_sigma2.shape}" + sigma2 = torch.exp(task_log_sigma2) # [T] + per_task_loss = per_task_loss / (2.0 * sigma2) + 0.5 * torch.log(sigma2) + + if reduction == "sum": + return per_task_loss.sum() + return per_task_loss.mean() + + +# --------------------------------------------------------- +# Curriculum scheduler for EXP fidelity +# --------------------------------------------------------- + +def exp_weight_at_epoch( + epoch: int, + total_epochs: int, + schedule: Literal["none", "linear", "cosine"] = "none", + start: float = 0.6, + end: float = 1.0, +) -> float: + """ + Returns the EXP loss weight for a given epoch under the chosen schedule. + """ + if schedule == "none": + return float(end) + epoch = max(0, min(epoch, total_epochs)) + if total_epochs <= 0: + return float(end) + t = epoch / float(total_epochs) + if schedule == "linear": + return float(start + (end - start) * t) + if schedule == "cosine": + cos_t = 0.5 - 0.5 * math.cos(math.pi * t) # 0->1 smoothly + return float(start + (end - start) * cos_t) + raise ValueError(f"Unknown schedule: {schedule}") + + +def make_fid_loss_weights( + fids: Sequence[str], + base_weights: Optional[Sequence[float]] = None, + exp_weight: Optional[float] = None, +) -> List[float]: + """ + Builds a per-fidelity weight vector aligned with dataset.fids order. + If exp_weight is provided, it overrides the weight for the 'exp' fidelity. + If base_weights is provided, it must match len(fids) and is used as a template. 
+ """ + fids_lc = [f.lower() for f in fids] + F = len(fids_lc) + if base_weights is None: + w = [1.0] * F + else: + assert len(base_weights) == F, f"base_weights len {len(base_weights)} != {F}" + w = [float(x) for x in base_weights] + if exp_weight is not None and "exp" in fids_lc: + idx = fids_lc.index("exp") + w[idx] = float(exp_weight) + return w + + +# --------------------------------------------------------- +# Inference utilities +# --------------------------------------------------------- + +def apply_inverse_transform(pred: torch.Tensor, scaler): + """ + Apply inverse target scaling safely on the same device as pred. + Works for CPU/GPU and legacy scalers. + """ + dev = pred.device + + # Move scaler tensors to pred device if needed + if hasattr(scaler, "mean") and scaler.mean.device != dev: + scaler.mean = scaler.mean.to(dev) + if hasattr(scaler, "std") and scaler.std.device != dev: + scaler.std = scaler.std.to(dev) + if hasattr(scaler, "eps") and scaler.eps is not None and scaler.eps.device != dev: + scaler.eps = scaler.eps.to(dev) + + return scaler.inverse(pred) + + + +def ensure_2d(x: torch.Tensor) -> torch.Tensor: + """Utility to guarantee [B, T] shape for single-task or squeezed outputs.""" + if x.dim() == 1: + return x.unsqueeze(1) + return x + + +# --------------------------------------------------------- +# Simple test harness (optional) +# --------------------------------------------------------- + +if __name__ == "__main__": + # Minimal sanity checks + torch.manual_seed(0) + B, T = 5, 3 + pred = torch.randn(B, T) + targ = torch.randn(B, T) + mask = torch.rand(B, T) > 0.3 + fid_idx = torch.randint(0, 4, (B,)) + fid_w = [1.0, 0.8, 0.6, 0.5] + task_w = [1.0, 2.0, 1.0] + + l1 = loss_multitask_fidelity(pred=pred, target=targ, mask=mask, fid_idx=fid_idx, fid_loss_w=fid_w, task_weights=task_w) + l2 = loss_multitask_fidelity(pred=pred, target=targ, mask=mask, fid_idx=fid_idx, fid_loss_w=fid_w, task_weights=None) + print("Loss with task weights:", float(l1)) + print("Loss without task weights:", float(l2)) + + m_all = masked_metrics_overall(pred, targ, mask) + print("Overall metrics:", m_all) diff --git a/tests/test_literature_core.py b/tests/test_literature_core.py new file mode 100644 index 0000000000000000000000000000000000000000..1301147a7a733aa2823029f65e9008d9f8296611 --- /dev/null +++ b/tests/test_literature_core.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +from literature.evaluation import evaluate_predictions +from literature.property_registry import detect_property_keys, normalize_property_key +from literature.schemas import ContextualizedValue, ExperimentalConditions, PaperMetadata, PaperSource +from literature.standardizer import UnitStandardizer +from src.literature_service import DataPointRepo, LiteraturePipeline, PaperRepo, ProjectRepo, get_database + + +class LiteratureCoreTests(unittest.TestCase): + def test_property_registry_detects_platform_properties(self) -> None: + detected = detect_property_keys("Looking for high thermal conductivity and low density polymers") + self.assertIn("tc", detected) + self.assertIn("rho", detected) + self.assertEqual(normalize_property_key("Young's modulus"), "young") + + def test_standardizer_converts_temperature_and_density(self) -> None: + standardizer = UnitStandardizer() + tg = standardizer.standardize("tg", "150", "C") + rho = standardizer.standardize("rho", "1200", "kg/m^3") + self.assertTrue(tg.success) + 
self.assertAlmostEqual(tg.value or 0.0, 423.15, places=2) + self.assertEqual(tg.unit, "K") + self.assertTrue(rho.success) + self.assertAlmostEqual(rho.value or 0.0, 1.2, places=3) + self.assertEqual(rho.unit, "g/cm^3") + + def test_evidence_review_is_idempotent(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + db_path = Path(tmp_dir) / "app.db" + db = get_database(db_path) + projects = ProjectRepo(db) + papers = PaperRepo(db) + evidence = DataPointRepo(db) + + project = projects.create_project("Test literature") + paper = papers.upsert_from_metadata( + project["id"], + PaperMetadata( + id="manual_test-paper", + title="Polyimide thermal conductivity", + source=PaperSource.MANUAL, + doi="10.1000/example", + is_open_access=True, + ), + ) + inserted = evidence.insert_points( + project["id"], + paper["id"], + None, + [ + ContextualizedValue( + polymer_name="Polyimide", + property_name="tc", + raw_value="0.25", + raw_unit="W/(m*K)", + conditions=ExperimentalConditions(measurement_method="laser flash"), + source_quote="Measured thermal conductivity was 0.25 W/(m*K) by laser flash analysis.", + source_location="Table 1", + extraction_confidence=0.92, + ) + ], + extractor_version="test-v1", + ) + self.assertEqual(inserted, 1) + + staged = evidence.list_evidence(project["id"]) + self.assertEqual(len(staged), 1) + first = staged[0] + evidence.update_review(first["id"], validation_status="approved", reviewer_note="looks good", action="approve") + evidence.update_review(first["id"], validation_status="approved", reviewer_note="looks good", action="approve") + refreshed = evidence.get_evidence(first["id"]) + self.assertEqual(refreshed["review_status"], "approved") + + def test_evaluation_harness_reports_scores(self) -> None: + gold = [ + { + "material_name": "Polyimide", + "property_key": "tg", + "raw_value": "450", + "raw_unit": "K", + "method": "DSC", + "evidence_quote": "The Tg was 450 K.", + } + ] + pred = [ + { + "material_name": "Polyimide", + "property_name": "glass transition temperature", + "raw_value": "450", + "raw_unit": "K", + "method": "DSC", + "source_quote": "The Tg was 450 K.", + } + ] + metrics = evaluate_predictions(gold, pred) + self.assertGreater(metrics["field_metrics"]["property_key"]["f1"], 0.99) + self.assertGreater(metrics["relation_level"]["f1"], 0.99) + self.assertGreater(metrics["source_grounding_hit_rate"], 0.99) + + def test_run_extraction_skips_cleanly_without_backend(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + db_path = Path(tmp_dir) / "app.db" + pipeline = LiteraturePipeline(db_path=str(db_path)) + project = pipeline.projects.create_project("Test extraction skip") + paper = pipeline.papers.upsert_from_metadata( + project["id"], + PaperMetadata( + id="manual_skip-paper", + title="Polyimide Tg paper", + source=PaperSource.MANUAL, + is_open_access=True, + ), + ) + + contextual_instance = MagicMock() + contextual_instance.is_configured.return_value = False + legacy_instance = MagicMock() + legacy_instance.can_attempt_extraction.return_value = False + + with patch("src.literature_service.pipeline.ContextualizedExtractor", return_value=contextual_instance), patch( + "src.literature_service.pipeline.DataExtractor", + return_value=legacy_instance, + ): + stats = pipeline.run_extraction( + project["id"], + run_id=None, + paper_rows=[paper], + target_properties=["tg"], + ) + + self.assertEqual(stats["skip_reason"], "no_extraction_backend") + self.assertEqual(stats["skipped_unconfigured"], 1) + job = 
pipeline.extraction_jobs.get_job(project["id"], paper["id"], "production-v1") + self.assertIsNotNone(job) + self.assertEqual(job["status"], "skipped") + self.assertIn("no LLM or PageIndex backend", job["error_message"]) + + def test_paper_cards_surface_extraction_skipped_status(self) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + db_path = Path(tmp_dir) / "app.db" + pipeline = LiteraturePipeline(db_path=str(db_path)) + project = pipeline.projects.create_project("Test skipped cards") + paper = pipeline.papers.upsert_from_metadata( + project["id"], + PaperMetadata( + id="manual_skip-card", + title="Polyimide density paper", + source=PaperSource.MANUAL, + is_open_access=True, + ), + ) + pipeline.extraction_jobs.upsert_job( + project["id"], + paper["id"], + "production-v1", + status="skipped", + error_message="Structured extraction skipped: no backend configured.", + ) + + cards = pipeline.get_paper_cards(project["id"]) + self.assertEqual(len(cards), 1) + self.assertEqual(cards[0].background_status, "extraction_skipped") + + +if __name__ == "__main__": + unittest.main()
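
For orientation, a minimal sketch of how a Streamlit page might call the styling helpers added above. The import path `src.ui_style` and all page text are assumptions; only the `apply_global_style` and `render_page_header` signatures come from this diff.

```python
# Minimal page sketch (hypothetical module path; adjust to the styling module in this diff).
import streamlit as st

from src.ui_style import apply_global_style, render_page_header  # assumed import path

apply_global_style()  # injects the global CSS and logs one visit per browser session
render_page_header(
    "Literature Mining",  # placeholder title
    subtitle="Project-based extraction of polymer property data",
    badge="beta",
)
st.write("Page content goes here.")
```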
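The `__main__` harness in `src/utils.py` only exercises the plain squared-error path. The sketch below, using random stand-in tensors and placeholder task/fidelity names, shows the heteroscedastic and uncertainty-weighted options described in the `loss_multitask_fidelity` docstring, plus the legacy (unbalanced) reduction for comparison.

```python
# Sketch: heteroscedastic NLL + learned task uncertainty, balanced vs. legacy reduction.
# All tensors are random stand-ins; task and fidelity names are placeholders.
import torch

from src.utils import (
    loss_multitask_fidelity,
    masked_metrics_by_fidelity,
    masked_metrics_per_task,
)

B, T, F = 8, 3, 2                                     # batch, tasks, fidelities
pred = torch.randn(B, T)                              # predicted means
hetero_logvar = torch.zeros(B, T)                     # per-element log-variance from the model head
target = torch.randn(B, T)
mask = torch.rand(B, T) > 0.4                         # True where a label exists
fid_idx = torch.randint(0, F, (B,))                   # per-row fidelity index
task_log_sigma2 = torch.zeros(T, requires_grad=True)  # learned homoscedastic term (Kendall & Gal)

loss = loss_multitask_fidelity(
    pred=pred, target=target, mask=mask, fid_idx=fid_idx,
    fid_loss_w=[1.0, 0.7],              # down-weight the second fidelity
    hetero_logvar=hetero_logvar,        # switches the element-wise loss to Gaussian NLL
    task_log_sigma2=task_log_sigma2,    # per-task 1/(2*sigma^2) weighting + 0.5*log(sigma^2)
    balanced=True,                      # mean per task first, then average over tasks
)
loss.backward()                         # gradients flow into task_log_sigma2

legacy = loss_multitask_fidelity(
    pred=pred, target=target, mask=mask, fid_idx=fid_idx,
    fid_loss_w=[1.0, 0.7], balanced=False,  # label-count-biased global reduction
)

per_task = masked_metrics_per_task(pred, target, mask, ["task_a", "task_b", "task_c"])
per_fid = masked_metrics_by_fidelity(
    pred, target, mask, fid_idx, ["low_fid", "exp"], ["task_a", "task_b", "task_c"]
)
print(float(loss), float(legacy), per_task, per_fid)
```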
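A sketch of the EXP-fidelity curriculum: `exp_weight_at_epoch` ramps the weight and `make_fid_loss_weights` places it at the `exp` position. The fidelity order `["dft", "md", "exp"]` is illustrative; as the docstring notes, it must match the dataset's `fids` order.

```python
# Sketch: cosine ramp of the EXP loss weight from 0.6 to 1.0 over 100 epochs.
from src.utils import exp_weight_at_epoch, make_fid_loss_weights

fids = ["dft", "md", "exp"]  # illustrative order; must match dataset.fids
total_epochs = 100

for epoch in range(total_epochs + 1):
    w_exp = exp_weight_at_epoch(epoch, total_epochs, schedule="cosine", start=0.6, end=1.0)
    fid_loss_w = make_fid_loss_weights(fids, base_weights=[1.0, 0.8, 1.0], exp_weight=w_exp)
    # fid_loss_w is what gets passed to loss_multitask_fidelity(..., fid_loss_w=fid_loss_w).
    if epoch in (0, 50, 100):
        print(epoch, [round(w, 3) for w in fid_loss_w])
# -> 0 [1.0, 0.8, 0.6]   50 [1.0, 0.8, 0.8]   100 [1.0, 0.8, 1.0]
```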
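Finally, a small sketch of the inference helpers. `DummyScaler` only mimics the attributes `apply_inverse_transform` touches (`mean`, `std`, `eps`, `inverse`); in practice the scaler would be the `TargetScaler` re-exported from `src.data_builder`.

```python
# Sketch: map scaled predictions back to physical units with a stand-in scaler.
import torch

from src.utils import apply_inverse_transform, ensure_2d


class DummyScaler:
    """Stand-in exposing the same attributes apply_inverse_transform expects."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor) -> None:
        self.mean, self.std, self.eps = mean, std, None

    def inverse(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.std + self.mean


scaler = DummyScaler(torch.tensor([300.0, 0.2]), torch.tensor([50.0, 0.05]))
pred_scaled = ensure_2d(torch.randn(4, 2))          # e.g. model output in scaled space, [B, T]
pred_phys = apply_inverse_transform(pred_scaled, scaler)
print(pred_phys.shape)                              # torch.Size([4, 2])
```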