Automated model and inference script upload v2
Browse files- README.md +68 -0
- models/atmosphere_category_tuned_feature_columns.joblib +3 -0
- models/atmosphere_category_tuned_imputer.joblib +3 -0
- models/atmosphere_category_tuned_label_encoder.joblib +3 -0
- models/atmosphere_category_tuned_lgbm_model.joblib +3 -0
- models/atmosphere_category_tuned_scaler.joblib +3 -0
- models/df_elements_processed.pkl +3 -0
- models/temperature_bin_tuned_feature_columns.joblib +3 -0
- models/temperature_bin_tuned_imputer.joblib +3 -0
- models/temperature_bin_tuned_label_encoder.joblib +3 -0
- models/temperature_bin_tuned_lgbm_model.joblib +3 -0
- models/temperature_bin_tuned_scaler.joblib +3 -0
- requirements.txt +8 -0
- src/__init__.py +1 -0
- src/constants.py +40 -0
- src/feature_engineering_utils.py +136 -0
- src/inference.py +224 -0
- src/process_feature_utils.py +192 -0
README.md
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
---
|
| 3 |
+
license: mit
|
| 4 |
+
language: en
|
| 5 |
+
tags:
|
| 6 |
+
- materials science
|
| 7 |
+
- synthesis prediction
|
| 8 |
+
- lightgbm
|
| 9 |
+
- cheminformatics
|
| 10 |
+
datasets: []
|
| 11 |
+
metrics:
|
| 12 |
+
- accuracy
|
| 13 |
+
- f1
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
# Synthesis Condition Predictor
|
| 17 |
+
|
| 18 |
+
This model predicts optimal temperature bins and atmosphere categories for inorganic material synthesis.
|
| 19 |
+
It was trained on a dataset of text-mined synthesis procedures.
|
| 20 |
+
|
| 21 |
+
**Models Included:**
|
| 22 |
+
* Temperature Bin Prediction (LightGBM)
|
| 23 |
+
* Atmosphere Category Prediction (LightGBM)
|
| 24 |
+
|
| 25 |
+
**Intended Use:**
|
| 26 |
+
To assist researchers in designing synthesis experiments by predicting key process parameters.
|
| 27 |
+
Input a target material, precursors, and basic operational details to get predictions.
|
| 28 |
+
|
| 29 |
+
**How to Use:**
|
| 30 |
+
```python
|
| 31 |
+
# Ensure your inference script and its dependencies are in the PYTHONPATH
|
| 32 |
+
# Example, if your repo is named 'synthesis_predictor_hf_repo' and it's in your path:
|
| 33 |
+
# from synthesis_predictor_hf_repo.src.inference import predict_synthesis_outcome, load_all_artifacts_once
|
| 34 |
+
|
| 35 |
+
# Or, if running from a cloned repo where 'src' is a subdirectory:
|
| 36 |
+
# from src.inference import predict_synthesis_outcome, load_all_artifacts_once
|
| 37 |
+
|
| 38 |
+
# if not load_all_artifacts_once():
|
| 39 |
+
# print("Failed to load model artifacts.")
|
| 40 |
+
# else:
|
| 41 |
+
# raw_input_example = {
|
| 42 |
+
# 'target_formula_raw': "YBa2Cu3O7",
|
| 43 |
+
# 'precursor_formulas_raw': ["Y2O3", "BaCO3", "CuO"],
|
| 44 |
+
# 'operations_simplified_list': [
|
| 45 |
+
# {'type': 'MixingOperation', 'string': 'Ball milling for 2h', 'conditions': {'duration': [{'value':2, 'unit':'h'}]}},
|
| 46 |
+
# {'type': 'HeatingOperation', 'string': 'Calcined at 920C for 10h in air',
|
| 47 |
+
# 'conditions': {'heating_temperature': [{'value':920}], 'heating_time': [{'value':10}], 'atmosphere':'air'}},
|
| 48 |
+
# {'type': 'HeatingOperation', 'string': 'Sintered at 950C for 20h in O2',
|
| 49 |
+
# 'conditions': {'heating_temperature': [{'value':950}], 'heating_time': [{'value':20}], 'atmosphere':'Oxygen'}}
|
| 50 |
+
# ],
|
| 51 |
+
# 'reactants_coeffs': [("Y2O3", 0.5), ("BaCO3", 2.0), ("CuO", 3.0)], # Example, adjust as needed
|
| 52 |
+
# 'products_coeffs': [("YBa2Cu3O7", 1.0)] # Example
|
| 53 |
+
# }
|
| 54 |
+
# predictions = predict_synthesis_outcome(raw_input_example)
|
| 55 |
+
# print(predictions)
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
**Limitations:**
|
| 59 |
+
* The model's accuracy is around 68-72%.
|
| 60 |
+
* Predictions are based on patterns in the training data and may not generalize to all chemical systems.
|
| 61 |
+
* The feature engineering for process parameters in the inference script relies on the user providing an `operations_simplified_list` that can be parsed by the internal logic. The quality of these inputs directly affects prediction accuracy.
|
| 62 |
+
|
| 63 |
+
**Training Data:**
|
| 64 |
+
The model was trained on a proprietary dataset of text-mined inorganic synthesis procedures.
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
**Evaluation Results:**
|
| 68 |
+
(Detailed evaluation metrics to be added; overall accuracy is roughly 68–72%, as noted under Limitations.)
|
models/atmosphere_category_tuned_feature_columns.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02b835987564e8176b65df3dc8b01231bbe3d4b070e9b7782e84e1c681ca40a0
|
| 3 |
+
size 42289
|
models/atmosphere_category_tuned_imputer.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78fb389f43864a52c88abdf683efba3b6d2201944c95a61e29d4c7926da7f5df
|
| 3 |
+
size 50367
|
models/atmosphere_category_tuned_label_encoder.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:28abe2833f4f1f63f8edc5f441003897e5b283ef62c4fc5403988192a716f7f2
|
| 3 |
+
size 546
|
models/atmosphere_category_tuned_lgbm_model.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e7f6ba5d42c8deaadd9c0ffe3c1cb7ad4cd55f17c37fd7c6b32e3c131a1e28cf
|
| 3 |
+
size 10168460
|
models/atmosphere_category_tuned_scaler.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:060c0cf8bed3af7c2fe4d01f5fb00dc65fff1adfc7f4fae37ce2318dde0bb810
|
| 3 |
+
size 65791
|
models/df_elements_processed.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:239fb692b5bbe23bbadc166427c25be0b139d2037e5266ce72061306d7dfe589
|
| 3 |
+
size 87794
|
models/temperature_bin_tuned_feature_columns.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:acc884465ac46f5693ef039d4aa6b4921f7b8c003e11d8e53c92086471dec81b
|
| 3 |
+
size 42314
|
models/temperature_bin_tuned_imputer.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98c6f43618fc49cce41d4c72402b9d5bbfd549b43cb345f2443cf9dc47468ec9
|
| 3 |
+
size 50223
|
models/temperature_bin_tuned_label_encoder.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d750eb6f78f9cade3b8cdd74308bbea8da6a3e96484b0e3966b065873e7d608a
|
| 3 |
+
size 576
|
models/temperature_bin_tuned_lgbm_model.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd02f53571b21580d42fcf6dcc63584d43c1d0b6273129a7b940a97d3a02d9d8
|
| 3 |
+
size 17173636
|
models/temperature_bin_tuned_scaler.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f9903d5d60c279bcd459b9ed3aaa1e5698660e399aea0b19a23a472aa4f44840
|
| 3 |
+
size 65599
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
pandas
|
| 3 |
+
numpy
|
| 4 |
+
scikit-learn
|
| 5 |
+
lightgbm
|
| 6 |
+
joblib
|
| 7 |
+
pymatgen
|
| 8 |
+
# matminer # Optional, if MAGPIE_FEATURIZER is used directly at inference and not just for labels
|
src/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Makes 'src' a package
|
src/constants.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Shared constants and optional matminer (Magpie) featurizer setup.
# Imported by feature_engineering_utils.py and inference.py.
from pymatgen.core import Element as PymatgenElement

# All element symbols pymatgen knows; used to validate parsed formulas.
KNOWN_ELEMENT_SYMBOLS = {el.symbol for el in PymatgenElement}

# Regex patterns mapping free-text atmosphere mentions to a specific gas label
# and a coarse redox category. Tuples are (regex, specific_label, category).
ATMOSPHERE_CONFIG = {
    "patterns": [
        (r'\b(air)\b', 'Air', 'Oxidizing'), (r'\b(O\s?2|oxygen)\b', 'O2', 'Oxidizing'),
        (r'\b(Ar|argon)\b', 'Ar', 'Inert'), (r'\b(N\s?2|nitrogen)\b', 'N2', 'Inert'),
        (r'\b(H\s?2/N\s?2|N\s?2/H\s?2|forming\s*gas)\b', 'FormingGas(N2/H2)', 'Reducing'),
        (r'\b(Ar/H\s?2|H\s?2/Ar)\b', 'Ar/H2', 'Reducing'), (r'\b(H\s?2|hydrogen)\b', 'H2', 'Reducing'),
        (r'\b(vacuum)\b', 'Vacuum', 'Vacuum'), (r'\b(He|helium)\b', 'He', 'Inert'),
        (r'\b(CO2|carbon\s*dioxide)\b', 'CO2', 'Neutral/Other'),
        (r'\b(CO|carbon\s*monoxide)\b', 'CO', 'Reducing'), (r'\b(NH3|ammonia)\b', 'NH3', 'Reducing/Other'),
    ], "default_specific": "unknown_atm_specific", "default_category": "Unknown_Atm_Category"
}

# Regex patterns mapping mixing-step descriptions to a canonical method name.
MIXING_METHOD_CONFIG = {
    "patterns": [
        (r'\b(ball\s*mill(?:ing)?)\b', 'ball_milling'), (r'\b(grind(?:ing)?|ground|pulverized|milled)\b', 'grinding'),
        (r'\b(solution|wet|homogeni[sz]ation|slurr(y|ies))\b', 'wet_method'),
        (r'\b(solid-state|solid\s*state(\s*reaction)?)\b', 'solid_state_mixing'),
        (r'\b(stir(?:ring)?)\b', 'stirring'), (r'\b(sonica(te|tion|ted))\b', 'sonication'),
        (r'\b(planetary\s*mill(?:ing)?)\b', 'planetary_milling'), (r'\b(attritor\s*mill(?:ing)?)\b', 'attritor_milling'),
        (r'\b(shaker\s*mill(?:ing)?)\b', 'shaker_milling'), (r'\b(mortar\s*(and\s*pestle)?)\b', 'mortar_pestle'),
    ], "default_method": "unknown_mix_method"
}

# Matminer is optional: when it is unavailable MAGPIE_LABELS stays empty and
# downstream code skips Magpie features entirely.
MAGPIE_LABELS = []
matminer_available = False
MAGPIE_FEATURIZER = None
try:
    from matminer.featurizers.composition import ElementProperty
    matminer_available = True
    MAGPIE_FEATURIZER = ElementProperty.from_preset("magpie", impute_nan=True)
    # Prefix and underscore-join labels so they are safe as DataFrame column names.
    MAGPIE_LABELS = [f'magpie_{label.replace(" ", "_")}' for label in MAGPIE_FEATURIZER.feature_labels()]
except ImportError:
    pass  # matminer not installed -- proceed without Magpie features
except Exception:
    pass  # NOTE(review): broad swallow also hides featurizer init errors -- consider logging
|
src/feature_engineering_utils.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pymatgen.core import Composition, Element as PymatgenElement
|
| 5 |
+
import ast
|
| 6 |
+
import re
|
| 7 |
+
import logging
|
| 8 |
+
from .constants import KNOWN_ELEMENT_SYMBOLS, MAGPIE_FEATURIZER, MAGPIE_LABELS, matminer_available
|
| 9 |
+
|
| 10 |
+
# --- Formula Cleaning and Standardization ---
|
| 11 |
+
def clean_formula_string_advanced(formula_str_original):
    """Strip non-stoichiometric decoration from a raw formula string.

    Removes things like leading Greek-phase prefixes ("α-"), hydrate suffixes
    ("·2H2O"), trailing qualifiers ("(powder)", "(99.9%)"), and converts
    embedded fractions ("2/3") to rounded decimals, so the result is closer
    to something pymatgen's Composition can parse.

    Non-string input is returned unchanged. The result is a best-effort
    cleanup, not guaranteed to be a valid formula.
    """
    if not isinstance(formula_str_original, str):
        return formula_str_original
    cleaned = formula_str_original.strip()
    # If the string ends in a parenthesized chunk that looks more like a formula
    # than the text before it (e.g. "barium titanate (BaTiO3)"), keep the chunk.
    paren_match = re.search(r'\(([^()]+)\)[^()]*$', cleaned)
    if paren_match:
        potential_formula_in_parens = paren_match.group(1).strip()
        part_before_parens = cleaned[:paren_match.start()].strip()
        if len(potential_formula_in_parens) > 1 and re.search(r"[A-Z]", potential_formula_in_parens) and re.fullmatch(r"[A-Za-z0-9\.\(\)\[\]]+", potential_formula_in_parens):
            if not part_before_parens or " " in part_before_parens or len(part_before_parens) > len(potential_formula_in_parens) + 5 or (part_before_parens.isalpha() and len(part_before_parens)>4) or re.fullmatch(r"\d+(\.\d+)?", part_before_parens) or re.fullmatch(r"\d*N", part_before_parens, re.IGNORECASE):
                cleaned = potential_formula_in_parens
            elif not re.search(r"[A-Za-z]", part_before_parens) and re.search(r"\d", part_before_parens):
                cleaned = potential_formula_in_parens
    # Drop Greek / single-letter phase prefixes: "α-Fe2O3" -> "Fe2O3".
    cleaned = re.sub(r"^[αΑβΒγΓδΔεΕζΖηΗθΘιΙκΚλΛμΜνΝξΞοΟπΠρΡσΣτΤυΥφΦχΧψΨωΩ]-", "", cleaned)
    cleaned = re.sub(r"^[a-zA-Z]-", "", cleaned)
    # Strip hydrate notation in its common spellings.
    cleaned = re.sub(r"[·*]\s*\d*(\.\d+)?[nNxX]?\s*H2O", "", cleaned)
    cleaned = re.sub(r"\s*\(\s*H2O\s*\)\s*\d*(\.\d+)?", "", cleaned)
    cleaned = re.sub(r"·\s*H2O", "", cleaned)
    # Trailing qualifiers such as "(powder)", "(5N)", "(99.9%)", "(gas)".
    cleaned = re.sub(r"\s*\(\s*(?:\d*N|\d+(?:\.\d+)?%?|solution|gas|powder|aq|amorphous|amorph|polytype|phase|\d{1,4})\s*\)\s*$", "", cleaned, flags=re.IGNORECASE)
    # Leading "(2.5)" style factors directly before a capitalized formula.
    cleaned = re.sub(r"^\s*\(\s*\d+(\.\d+)?\s*\)\s*(?=[A-Z])", "", cleaned)

    def replace_frac(match):
        # Turn "a/b" into a 4-decimal value; leave the text alone when the
        # numbers don't parse or the denominator is zero.
        try:
            num = float(match.group(1))
            den = float(match.group(2))
            return str(round(num / den, 4)) if den != 0 else match.group(0)
        except (ValueError, ZeroDivisionError):  # was a bare except -- narrowed
            return match.group(0)

    cleaned = re.sub(r"(?<=[A-Za-z\d\)])(\d+)\s*/\s*(\d+)", replace_frac, cleaned)
    cleaned = re.sub(r"^(\d+)\s*/\s*(\d+)", replace_frac, cleaned)
    # Remaining dashed prefixes ("cubic-", "(hex)-", etc.).
    cleaned = re.sub(r"^\s*\(?[a-zA-Z\s]+\)?-", "", cleaned)
    cleaned = re.sub(r"^[a-zA-Z]+-", "", cleaned)
    cleaned = cleaned.strip(" .,;·*()")
    return cleaned
|
| 38 |
+
|
| 39 |
+
def is_plausible_formula_for_pymatgen(cleaned_formula_str, entry_identifier):
    """Heuristically decide whether a cleaned string can be fed to pymatgen.

    Rejects non-strings, blanks, reaction-like strings (containing '+', '==',
    '->' or ';'), and formulas with unresolved stoichiometric variables
    (x, y, z, δ, n) that are not actual element symbols.
    `entry_identifier` is accepted for call-site context but not used here.
    """
    if not isinstance(cleaned_formula_str, str) or not cleaned_formula_str.strip():
        return False
    # Equations and multi-compound lists are never a single formula.
    if any(token in cleaned_formula_str for token in ('+', '==', '->', ';')):
        return False
    indicator_patterns = [
        r"[A-Za-z]\d*\s*[-+*]\s*[xyzδδn]", r"[xyzδδn]\s*[-+*]",
        r"[A-Za-z]\d*\(\s*\d*\s*[-+]\s*[xyzδδn]\s*\)?",
        r"(?<![A-Za-z])(?:[1-9]\d*|0)?\.\d*[xyzδδn]",
        r"[xyzδδn]\d+", r"[A-Za-z]\s*[xyzδδn]\s*\d*", r"1-[xyzδδn]",
    ]
    lone_var_regex = r"(?i)(?<![A-Z])([xyzδδn])(?![a-z])"
    for indicator in indicator_patterns:
        if not re.search(indicator, cleaned_formula_str, re.IGNORECASE):
            continue
        # Something looks variable-like; reject only if a lone variable letter
        # cannot be read as a real element symbol.
        candidates = re.findall(lone_var_regex, cleaned_formula_str)
        if any(c.upper() not in KNOWN_ELEMENT_SYMBOLS for c in candidates if len(c) == 1):
            return False
    return True
|
| 49 |
+
|
| 50 |
+
def standardize_chemical_formula(raw_formula_str, entry_identifier="Unknown_Entry"):
    """Normalize a raw formula string to a reduced pymatgen formula.

    Returns one of:
      * str  -- the reduced stoichiometric formula (e.g. 'YBa2Cu3O7'),
      * dict -- {'type': 'elements_only', 'elements': set, 'original_cleaned': str}
                when only the element set could be recovered,
      * None -- when nothing chemical could be extracted.

    `entry_identifier` only feeds the context string handed to the
    plausibility check; it does not affect the result.
    """
    if not isinstance(raw_formula_str, str) or not raw_formula_str.strip(): return None
    cleaned_formula_str = clean_formula_string_advanced(raw_formula_str)
    if not cleaned_formula_str: return None
    # Only hand plausibly-stoichiometric strings to pymatgen; variable-laden
    # formulas (x, y, δ, ...) would parse incorrectly or raise.
    if is_plausible_formula_for_pymatgen(cleaned_formula_str, f"{entry_identifier} (Original: '{raw_formula_str}', Cleaned: '{cleaned_formula_str}')"):
        try:
            comp_formula_for_pymatgen = cleaned_formula_str.replace(" ", "")
            if not comp_formula_for_pymatgen: return None
            comp = Composition(comp_formula_for_pymatgen)
            # Accept only if every parsed element is a recognized symbol.
            if all(el.symbol in KNOWN_ELEMENT_SYMBOLS for el in comp.elements): return comp.get_reduced_formula_and_factor()[0].replace(" ", "")
        except Exception: pass  # parse failure: fall through to element-set extraction
    # Fallback: recover just the set of recognized element symbols.
    extracted_elements = {el for el in re.findall(r"([A-Z][a-z]?)", cleaned_formula_str) if el in KNOWN_ELEMENT_SYMBOLS}
    if extracted_elements: return {'type': 'elements_only', 'elements': extracted_elements, 'original_cleaned': cleaned_formula_str}
    return None
|
| 64 |
+
|
| 65 |
+
def get_valence_features(valences_input, entry_identifier="Unknown_Entry"):
    """Summarize a list of valences as avg/min/max.

    Accepts a list of numbers or its string repr (e.g. "[2, 3]"). Returns a
    dict with keys 'avg_valence', 'min_valence', 'max_valence'; all values are
    NaN when the input cannot be parsed into a non-empty numeric list.
    `entry_identifier` is accepted for call-site context only.
    """
    nan_result = {'avg_valence': np.nan, 'min_valence': np.nan, 'max_valence': np.nan}
    parsed = valences_input
    if isinstance(valences_input, str):
        # Stored lists may arrive as their string repr.
        try:
            parsed = ast.literal_eval(valences_input)
        except (ValueError, SyntaxError, TypeError):
            parsed = []
    if not isinstance(parsed, list) or not parsed:
        return nan_result
    numeric = [v for v in parsed if isinstance(v, (int, float))]
    if not numeric:
        return nan_result
    return {
        'avg_valence': np.mean(numeric),
        'min_valence': np.min(numeric),
        'max_valence': np.max(numeric),
    }
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def generate_compositional_features(formula_input, df_elements_processed, entry_identifier="Unknown_Formula"):
    """Build a flat dict of composition-derived features for one formula.

    `formula_input` is the output of standardize_chemical_formula: a formula
    string, an 'elements_only' dict, or None. `df_elements_processed` is a
    DataFrame of per-element properties indexed by element symbol (assumes
    columns like 'Atomic_Weight', 'Electronegativity' -- TODO confirm against
    the pickled artifact). Every expected key is present in the returned dict;
    anything that could not be computed stays NaN (or 0/False for counters).
    """
    # Start from a dict holding every feature key with its "missing" value so
    # callers always see a fixed schema regardless of input quality.
    default_feature_dict = {'is_stoichiometric_formula': False, 'num_elements_in_formula': 0}
    basic_props = ['avg_atomic_weight', 'avg_electronegativity', 'avg_atomic_radius', 'avg_melting_point', 'avg_density', 'avg_specific_heat', 'avg_thermal_conductivity', 'avg_heat_of_fusion', 'sum_atomic_weight', 'range_electronegativity', 'min_electronegativity', 'max_electronegativity', 'var_atomic_radius', 'min_atomic_radius', 'max_atomic_radius', 'avg_valence_of_comp', 'avg_est_valence_electrons']
    unweighted_props = [f'avg_{prop.lower()}_unweighted' for prop in ['Atomic_Weight', 'Electronegativity', 'Atomic_Radius', 'Melting_Point', 'Density', 'avg_valence', 'valence_electrons_estimated']] + [f'min_{prop.lower()}_unweighted' for prop in ['Atomic_Weight', 'Electronegativity', 'Atomic_Radius', 'Melting_Point', 'Density', 'avg_valence', 'valence_electrons_estimated']] + [f'max_{prop.lower()}_unweighted' for prop in ['Atomic_Weight', 'Electronegativity', 'Atomic_Radius', 'Melting_Point', 'Density', 'avg_valence', 'valence_electrons_estimated']] + [f'var_{prop.lower()}_unweighted' for prop in ['Atomic_Weight', 'Electronegativity', 'Atomic_Radius', 'Melting_Point', 'Density', 'avg_valence', 'valence_electrons_estimated']]
    for k in basic_props + unweighted_props: default_feature_dict[k] = np.nan
    if matminer_available and MAGPIE_LABELS:
        for label in MAGPIE_LABELS: default_feature_dict[label] = np.nan

    if formula_input is None: return default_feature_dict.copy()
    features = {}
    if isinstance(formula_input, str):
        # Stoichiometric branch: amount-weighted element-property averages.
        try:
            comp = Composition(formula_input); el_amt_dict = comp.get_el_amt_dict(); total_atoms = sum(el_amt_dict.values())
            if total_atoms == 0: return {**default_feature_dict, 'is_stoichiometric_formula': False}
            features['is_stoichiometric_formula'] = True; features['num_elements_in_formula'] = len(el_amt_dict)
            # feature key -> column name in df_elements_processed
            props_for_avg_mapping = {'avg_atomic_weight': 'Atomic_Weight', 'avg_electronegativity': 'Electronegativity', 'avg_atomic_radius': 'Atomic_Radius', 'avg_melting_point': 'Melting_Point', 'avg_density': 'Density', 'avg_specific_heat': 'Specific_Heat', 'avg_thermal_conductivity': 'Thermal_Conductivity', 'avg_heat_of_fusion': 'Heat_of_Fusion', 'avg_valence_of_comp': 'avg_valence', 'avg_est_valence_electrons': 'valence_electrons_estimated'}
            element_values_for_stats_mapping = {'electronegativity': 'Electronegativity', 'atomic_radius': 'Atomic_Radius'}
            current_props_for_avg = {k: [] for k in props_for_avg_mapping.keys()}; current_element_values_for_stats = {k: [] for k in element_values_for_stats_mapping.keys()}; valid_elements_for_avg_count = {k: 0 for k in props_for_avg_mapping.keys()}
            for el_obj, amt in el_amt_dict.items():
                el_symbol_str = el_obj.symbol if isinstance(el_obj, PymatgenElement) else str(el_obj)
                if el_symbol_str not in KNOWN_ELEMENT_SYMBOLS: continue
                if el_symbol_str in df_elements_processed.index:
                    el_props_series = df_elements_processed.loc[el_symbol_str]
                    for feat_key, elem_col_name in props_for_avg_mapping.items():
                        val = el_props_series.get(elem_col_name, np.nan)
                        # Weight each property by the element's amount; track the
                        # weight actually used per feature so averages stay correct
                        # when some elements lack a value.
                        if pd.notna(val): current_props_for_avg[feat_key].append(val * amt); valid_elements_for_avg_count[feat_key] += amt
                    for feat_key, elem_col_name in element_values_for_stats_mapping.items():
                        val = el_props_series.get(elem_col_name, np.nan)
                        # Repeat the value per (rounded) atom count for range/var stats.
                        if pd.notna(val): current_element_values_for_stats[feat_key].extend([val] * int(round(amt)))
            for key, val_list in current_props_for_avg.items(): features[key] = np.nansum(val_list) / valid_elements_for_avg_count[key] if valid_elements_for_avg_count[key] > 0 else np.nan
            features['sum_atomic_weight'] = comp.weight
            for key, val_list in current_element_values_for_stats.items():
                clean_val_list = [v for v in val_list if pd.notna(v)]
                if clean_val_list: features[f'range_{key}'] = np.max(clean_val_list) - np.min(clean_val_list); features[f'min_{key}'] = np.min(clean_val_list); features[f'max_{key}'] = np.max(clean_val_list); features[f'var_{key}'] = np.var(clean_val_list)
                else:
                    for stat in ['range_', 'min_', 'max_', 'var_']: features[f'{stat}{key}'] = np.nan
            if matminer_available and MAGPIE_FEATURIZER:
                try:
                    magpie_vals = MAGPIE_FEATURIZER.featurize(comp)
                    for i, label in enumerate(MAGPIE_LABELS): features[label] = magpie_vals[i]
                except: pass  # NOTE(review): bare except silently drops Magpie features
        except: features['is_stoichiometric_formula'] = False  # NOTE(review): bare except; any error demotes to non-stoichiometric
    elif isinstance(formula_input, dict) and formula_input.get('type') == 'elements_only':
        # Elements-only branch: unweighted stats over whatever elements we know.
        features['is_stoichiometric_formula'] = False
        elements_present = formula_input.get('elements', set())
        valid_elements = [el for el in elements_present if el in df_elements_processed.index]
        features['num_elements_in_formula'] = len(valid_elements)
        if valid_elements:
            element_props_subset = df_elements_processed.loc[valid_elements]
            unweighted_props_to_calc = ['Atomic_Weight', 'Electronegativity', 'Atomic_Radius', 'Melting_Point', 'Density', 'avg_valence', 'valence_electrons_estimated']
            for prop_col in unweighted_props_to_calc:
                if prop_col in element_props_subset.columns:
                    clean_vals = element_props_subset[prop_col].dropna()
                    if not clean_vals.empty:
                        features[f'avg_{prop_col.lower()}_unweighted'] = clean_vals.mean()
                        features[f'min_{prop_col.lower()}_unweighted'] = clean_vals.min()
                        features[f'max_{prop_col.lower()}_unweighted'] = clean_vals.max()
                        features[f'var_{prop_col.lower()}_unweighted'] = clean_vals.var()

    # Overlay computed values on the full default schema.
    final_features = default_feature_dict.copy(); final_features.update(features)
    return final_features
|
src/inference.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import joblib
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
from pymatgen.core import Composition
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
from .constants import KNOWN_ELEMENT_SYMBOLS, ATMOSPHERE_CONFIG, MIXING_METHOD_CONFIG, MAGPIE_FEATURIZER, MAGPIE_LABELS, matminer_available
|
| 11 |
+
from .feature_engineering_utils import standardize_chemical_formula, generate_compositional_features
|
| 12 |
+
from .process_feature_utils import generate_process_features_for_input, generate_stoichiometry_features_for_input
|
| 13 |
+
|
| 14 |
+
# Artifact locations relative to this file's directory; resolved against
# os.path.dirname(__file__) inside load_all_artifacts_once().
MODEL_DIR = "../models"
PREPROCESSOR_DIR = "../models"
# Pickled DataFrame of per-element properties used for feature engineering.
ELEMENTAL_DATA_PATH = os.path.join(MODEL_DIR, "df_elements_processed.pkl")

# Module-level caches populated once by load_all_artifacts_once().
ESSENTIAL_OBJECTS = {}
DF_ELEMENTS_PROCESSED_GLOBAL = None
|
| 20 |
+
|
| 21 |
+
def load_all_artifacts_once():
    """Load elemental data, models, and preprocessors into module-level caches.

    Idempotent: returns immediately when a previous call fully succeeded.
    Populates DF_ELEMENTS_PROCESSED_GLOBAL and ESSENTIAL_OBJECTS
    ("models", "encoders", "imputers", "scalers", "feature_columns").

    Returns:
        bool: True when every artifact loaded; False when the elemental data
        is missing (fatal) or any per-model artifact failed to load.
    """
    global DF_ELEMENTS_PROCESSED_GLOBAL, ESSENTIAL_OBJECTS, matminer_available, MAGPIE_FEATURIZER, MAGPIE_LABELS
    if ESSENTIAL_OBJECTS.get("loaded_successfully"):
        logging.info("Artifacts already loaded.")
        return True

    logging.info("--- Loading Essential Artifacts for Prediction ---")
    script_dir = os.path.dirname(__file__)

    # Build the path outside the try so the error message below can always
    # reference it (previously a NameError risk inside the except handler).
    elemental_data_full_path = os.path.join(script_dir, ELEMENTAL_DATA_PATH)
    try:
        DF_ELEMENTS_PROCESSED_GLOBAL = pd.read_pickle(elemental_data_full_path)
        ESSENTIAL_OBJECTS["elemental_data"] = DF_ELEMENTS_PROCESSED_GLOBAL
        logging.info(f"Loaded processed elemental data from {elemental_data_full_path}")
    except Exception as e:
        # Elemental data is required by every downstream feature builder.
        logging.critical(f"CRITICAL: Error loading elemental data from {elemental_data_full_path}: {e}")
        return False

    if not matminer_available:  # Attempt to re-init if constants.py didn't catch it
        # NOTE(review): this rebinds only THIS module's copies of the matminer
        # globals; feature_engineering_utils imported its own copies from
        # constants and will not see this re-init -- verify whether that matters.
        try:
            from matminer.featurizers.composition import ElementProperty
            MAGPIE_FEATURIZER = ElementProperty.from_preset("magpie", impute_nan=True)
            MAGPIE_LABELS = [f'magpie_{label.replace(" ", "_")}' for label in MAGPIE_FEATURIZER.feature_labels()]
            matminer_available = True
            logging.info("Matminer re-initialized in inference script.")
        except Exception:  # narrowed from bare except: keep KeyboardInterrupt/SystemExit alive
            logging.warning("Matminer could not be re-initialized in inference script.")

    ESSENTIAL_OBJECTS["models"] = {}
    ESSENTIAL_OBJECTS["encoders"] = {}
    ESSENTIAL_OBJECTS["imputers"] = {}
    ESSENTIAL_OBJECTS["scalers"] = {}
    ESSENTIAL_OBJECTS["feature_columns"] = {}

    all_loaded_successfully = True
    for model_type_key in ["temperature_bin", "atmosphere_category"]:
        model_artifact_name = f"{model_type_key}_tuned"
        try:
            # Five artifacts per target: model, label encoder, imputer, scaler,
            # and the ordered feature-column list the model was trained on.
            ESSENTIAL_OBJECTS["models"][model_type_key] = joblib.load(os.path.join(script_dir, MODEL_DIR, f"{model_artifact_name}_lgbm_model.joblib"))
            ESSENTIAL_OBJECTS["encoders"][model_type_key] = joblib.load(os.path.join(script_dir, MODEL_DIR, f"{model_artifact_name}_label_encoder.joblib"))
            ESSENTIAL_OBJECTS["imputers"][model_type_key] = joblib.load(os.path.join(script_dir, PREPROCESSOR_DIR, f"{model_artifact_name}_imputer.joblib"))
            ESSENTIAL_OBJECTS["scalers"][model_type_key] = joblib.load(os.path.join(script_dir, PREPROCESSOR_DIR, f"{model_artifact_name}_scaler.joblib"))
            ESSENTIAL_OBJECTS["feature_columns"][model_type_key] = joblib.load(os.path.join(script_dir, PREPROCESSOR_DIR, f"{model_artifact_name}_feature_columns.joblib"))
            logging.info(f"Loaded artifacts for {model_artifact_name} model.")
        except Exception as e:
            # A missing artifact disables that target but we keep loading the rest.
            logging.error(f"Error loading one or more artifacts for '{model_artifact_name}': {e}. Predictions for it may fail.")
            ESSENTIAL_OBJECTS["models"][model_type_key] = None
            all_loaded_successfully = False

    ESSENTIAL_OBJECTS["loaded_successfully"] = all_loaded_successfully
    return all_loaded_successfully
|
| 73 |
+
|
| 74 |
+
def create_feature_vector_for_prediction(raw_synthesis_input, model_target_name):
    """Build the single-row, imputed and scaled feature DataFrame for one model.

    Args:
        raw_synthesis_input: dict with optional keys 'target_formula_raw',
            'precursor_formulas_raw', 'operations_simplified_list',
            'reactants_simplified', 'products_simplified'.
        model_target_name: key into ESSENTIAL_OBJECTS sub-dicts, e.g.
            "temperature_bin" or "atmosphere_category".

    Returns:
        A one-row pandas DataFrame ordered by the model's expected feature
        columns, or None when artifacts are missing or transform fails.
    """
    global DF_ELEMENTS_PROCESSED_GLOBAL, ESSENTIAL_OBJECTS

    # Elemental reference table is loaded as a side effect of
    # load_all_artifacts_once(); without it no compositional features exist.
    if DF_ELEMENTS_PROCESSED_GLOBAL is None:
        logging.error("Elemental data not loaded. Call load_all_artifacts_once() first.")
        return None

    expected_feature_cols = ESSENTIAL_OBJECTS["feature_columns"].get(model_target_name)
    if not expected_feature_cols:
        logging.error(f"Feature column list for '{model_target_name}' not found in loaded artifacts.")
        return None

    # Initialize every expected column: one-hot/flag-style columns default to 0,
    # everything else to NaN (so the imputer can fill it later).
    feature_dict = {col: (0 if col.startswith(("ops_", "proc_has_", "elem_block_")) or "is_stoichiometric" in col or "is_elements_only" in col else np.nan) for col in expected_feature_cols}

    # Target Compositional Features
    std_target_output = standardize_chemical_formula(raw_synthesis_input.get('target_formula_raw'), "predict_target")
    target_comp_feats = generate_compositional_features(std_target_output, DF_ELEMENTS_PROCESSED_GLOBAL, "predict_target_comp")
    for k, v in target_comp_feats.items():
        feature_key = f'target_{k}'
        # Only keep features the model was trained on.
        if feature_key in feature_dict: feature_dict[feature_key] = v

    # Precursor Compositional Features
    precursor_formulas_raw = raw_synthesis_input.get('precursor_formulas_raw', [])
    std_precursors_outputs = [standardize_chemical_formula(p, f"predict_prec_{i}") for i, p in enumerate(precursor_formulas_raw)]
    num_valid_precursors, num_stoich_precursors, num_elements_only_precursors = 0,0,0
    precursor_comp_feats_list = []
    for std_p_output in std_precursors_outputs:
        if std_p_output is not None:
            num_valid_precursors += 1
            # standardize_chemical_formula apparently returns either a plain
            # formula string (stoichiometric) or a dict describing a special
            # case such as {'type': 'elements_only', ...} — TODO confirm.
            if isinstance(std_p_output, str): num_stoich_precursors += 1
            elif isinstance(std_p_output, dict) and std_p_output.get('type') == 'elements_only': num_elements_only_precursors +=1
            precursor_comp_feats_list.append(generate_compositional_features(std_p_output, DF_ELEMENTS_PROCESSED_GLOBAL, "predict_prec_comp"))

    feature_dict['num_valid_precursors'] = num_valid_precursors
    feature_dict['all_prec_are_stoichiometric'] = (num_stoich_precursors == num_valid_precursors) if num_valid_precursors > 0 else False
    feature_dict['any_prec_is_elements_only'] = (num_elements_only_precursors > 0) if num_valid_precursors > 0 else False

    # Aggregate per-precursor numeric compositional features (mean/std/min/max/sum).
    if precursor_comp_feats_list:
        df_prec_feats = pd.DataFrame(precursor_comp_feats_list)
        numeric_cols_df_prec = df_prec_feats.select_dtypes(include=np.number)
        if not numeric_cols_df_prec.empty:
            # Use a throwaway H2O featurization just to discover which
            # compositional keys are numeric and therefore aggregatable.
            temp_sample_df = pd.DataFrame([generate_compositional_features("H2O", DF_ELEMENTS_PROCESSED_GLOBAL)])
            numeric_sample_comp_keys = [k for k in temp_sample_df.columns if pd.api.types.is_numeric_dtype(temp_sample_df[k]) and k not in ['is_stoichiometric_formula']]
            for agg_func_name in ['mean', 'std', 'min', 'max', 'sum']:
                aggregated_vals = getattr(numeric_cols_df_prec, agg_func_name)()
                for feat_name_suffix in numeric_sample_comp_keys:
                    agg_feat_key = f"{agg_func_name}_prec_{feat_name_suffix}"
                    if agg_feat_key in feature_dict and feat_name_suffix in aggregated_vals:
                        feature_dict[agg_feat_key] = aggregated_vals[feat_name_suffix]

    # Process Features
    process_input_ops_list = raw_synthesis_input.get('operations_simplified_list', [])
    # Recover the category/method vocabularies from the one-hot column names
    # the model was trained with, so OHE columns line up at inference time.
    all_atm_cats = list(set([col.split('ops_atm_cat_')[-1] for col in expected_feature_cols if col.startswith('ops_atm_cat_')]))
    all_mix_meths = list(set([col.split('ops_mix_meth_')[-1] for col in expected_feature_cols if col.startswith('ops_mix_meth_')]))
    proc_feats_generated = generate_process_features_for_input(process_input_ops_list, all_atm_cats, all_mix_meths)
    for k, v in proc_feats_generated.items():
        if k in feature_dict: feature_dict[k] = v

    # Stoichiometry features
    reactants_simplified = raw_synthesis_input.get('reactants_simplified', [])
    products_simplified = raw_synthesis_input.get('products_simplified', [])
    stoich_feats_generated = generate_stoichiometry_features_for_input(reactants_simplified, products_simplified, standardize_chemical_formula)
    for k, v in stoich_feats_generated.items():
        if k in feature_dict: feature_dict[k] = v

    # Column order must match training exactly for the imputer/scaler/model.
    feature_vector_df = pd.DataFrame([feature_dict], columns=expected_feature_cols)

    # Impute and Scale
    imputer = ESSENTIAL_OBJECTS["imputers"].get(model_target_name)
    scaler = ESSENTIAL_OBJECTS["scalers"].get(model_target_name)

    # Transform only continuous numeric columns; one-hot and boolean-style
    # columns (ops_*, proc_has_*, elem_block_*, flags/counts) are excluded.
    numerical_features_for_transform = [col for col in expected_feature_cols if col in feature_vector_df.columns and pd.api.types.is_numeric_dtype(feature_vector_df[col].dtype) and not col.startswith('ops_') and not col.startswith('proc_has_') and not col.startswith('elem_block_') and col not in ['is_stoichiometric_formula', 'all_prec_are_stoichiometric', 'any_prec_is_elements_only', 'num_valid_precursors']]

    if imputer and scaler and numerical_features_for_transform:
        try:
            feature_vector_df[numerical_features_for_transform] = feature_vector_df[numerical_features_for_transform].astype(np.float64)
            feature_vector_df[numerical_features_for_transform] = imputer.transform(feature_vector_df[numerical_features_for_transform])
            feature_vector_df[numerical_features_for_transform] = scaler.transform(feature_vector_df[numerical_features_for_transform])
            logging.info("Feature vector imputed and scaled for prediction.")
        except Exception as e_transform:
            logging.error(f"Error during imputation/scaling for prediction: {e_transform}", exc_info=True)
            return None
    else:
        # NOTE(review): falls through with raw (unscaled) features here —
        # downstream model quality will suffer; confirm this is intentional.
        logging.warning("Imputer, Scaler or numerical features missing for prediction. Proceeding with caution.")
    return feature_vector_df
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def predict_synthesis_outcome(raw_synthesis_input):
    """Predict temperature bin and atmosphere category for one synthesis recipe.

    Lazily loads all artifacts on first call. For each available model, builds
    the feature vector, predicts, and decodes the label via the matching
    LabelEncoder.

    Args:
        raw_synthesis_input: same raw-recipe dict accepted by
            create_feature_vector_for_prediction().

    Returns:
        dict keyed by model type ("temperature_bin", "atmosphere_category").
        Each value is either {'predicted_label': ..., 'probabilities': {...}}
        on success, or an error-description string on failure. Empty dict when
        artifacts cannot be loaded at all.
    """
    global ESSENTIAL_OBJECTS
    # Lazy one-time artifact load.
    if not ESSENTIAL_OBJECTS.get("loaded_successfully"):
        success = load_all_artifacts_once()
        if not success:
            logging.error("Essential artifacts could not be loaded. Cannot make predictions.")
            return {}

    predictions = {}
    model_types_to_predict = ["temperature_bin", "atmosphere_category"]

    for model_type in model_types_to_predict:
        # Skip models whose artifacts failed to load (stored as None).
        if ESSENTIAL_OBJECTS["models"].get(model_type):
            logging.info(f"\n--- Predicting {model_type} ---")
            feature_vector = create_feature_vector_for_prediction(raw_synthesis_input, model_type)

            if feature_vector is not None:
                model = ESSENTIAL_OBJECTS["models"][model_type]
                encoder = ESSENTIAL_OBJECTS["encoders"][model_type]
                try:
                    pred_encoded = model.predict(feature_vector)
                    pred_proba = model.predict_proba(feature_vector)
                    # Single-row input: take the first decoded label.
                    pred_label = encoder.inverse_transform(pred_encoded)[0]

                    predictions[model_type] = {
                        'predicted_label': pred_label,
                        # Map every class name to its predicted probability.
                        'probabilities': {str(cls): prob for cls, prob in zip(encoder.classes_, pred_proba[0])}
                    }
                    logging.info(f"Predicted {model_type}: {pred_label}")
                    logging.info(f"Probabilities: {predictions[model_type]['probabilities']}")
                except Exception as e:
                    # Record the failure in the result instead of raising, so one
                    # model's error does not block the other's prediction.
                    logging.error(f"Error during {model_type} prediction: {e}", exc_info=True)
                    predictions[model_type] = f"Prediction Error: {e}"
            else:
                logging.error(f"Could not create feature vector for {model_type} model.")
                predictions[model_type] = "Feature vector creation error"
        else:
            logging.warning(f"{model_type} model not available for prediction.")

    return predictions
|
| 201 |
+
|
| 202 |
+
if __name__ == '__main__':
    # This block is for testing this inference script directly.

    # Ensure artifacts are loaded
    if not load_all_artifacts_once():
        print("Exiting due to failure in loading essential artifacts.")
    else:
        print("\n--- Example Interactive Prediction ---")
        # Classic YBCO solid-state synthesis used as a smoke-test input.
        example_input_with_ops_list = {
            'target_formula_raw': "YBa2Cu3O7",
            'precursor_formulas_raw': ["Y2O3", "BaCO3", "CuO"],
            'operations_simplified_list': [
                {'type': 'MixingOperation', 'string': 'Mix precursors by ball milling for 4h', 'conditions': {'duration': [{'value':4, 'unit':'h'}]}},
                {'type': 'HeatingOperation', 'string': 'Calcined at 900C for 12h in air', 'conditions': {'heating_temperature': [{'value':900, 'unit':'C'}], 'heating_time': [{'value':12, 'unit':'h'}], 'atmosphere': 'Air'}},
                # NOTE(review): the free-text says "24h" but the structured
                # heating_time value is 20 — inconsistent example data; confirm.
                {'type': 'HeatingOperation', 'string': 'Sintered at 950C for 24h in O2', 'conditions': {'heating_temperature': [{'value':950, 'unit':'C'}], 'heating_time': [{'value':20, 'unit':'h'}], 'atmosphere': 'Oxygen'}}
            ],
            'reactants_simplified': [{'material': 'Y2O3', 'amount': 0.5}, {'material':'BaCO3', 'amount': 2.0}, {'material':'CuO', 'amount': 3.0}],
            'products_simplified': [{'material':'YBa2Cu3O7', 'amount': 1.0}]
        }

        predictions = predict_synthesis_outcome(example_input_with_ops_list)
        print(f"\nFinal Predictions for example input: {predictions}")
|
| 224 |
+
|
src/process_feature_utils.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import re
|
| 5 |
+
from .constants import ATMOSPHERE_CONFIG, MIXING_METHOD_CONFIG
|
| 6 |
+
|
| 7 |
+
def _extract_numerical_value_from_op_condition(condition_entry, target_keys=['value', 'max_value', 'values']):
|
| 8 |
+
if isinstance(condition_entry, list) and condition_entry:
|
| 9 |
+
if isinstance(condition_entry[0], dict):
|
| 10 |
+
for key in target_keys:
|
| 11 |
+
val = condition_entry[0].get(key)
|
| 12 |
+
if val is not None:
|
| 13 |
+
if isinstance(val, list) and val:
|
| 14 |
+
try: return float(val[0])
|
| 15 |
+
except: continue
|
| 16 |
+
try: return float(val)
|
| 17 |
+
except: continue
|
| 18 |
+
elif isinstance(condition_entry[0], (int, float, np.number)):
|
| 19 |
+
try: return float(condition_entry[0])
|
| 20 |
+
except: pass
|
| 21 |
+
elif isinstance(condition_entry, dict):
|
| 22 |
+
for key in target_keys:
|
| 23 |
+
val = condition_entry.get(key)
|
| 24 |
+
if val is not None:
|
| 25 |
+
if isinstance(val, list) and val:
|
| 26 |
+
try: return float(val[0])
|
| 27 |
+
except: continue
|
| 28 |
+
try: return float(val)
|
| 29 |
+
except: continue
|
| 30 |
+
elif isinstance(condition_entry, (int, float, np.number)):
|
| 31 |
+
try: return float(condition_entry)
|
| 32 |
+
except: pass
|
| 33 |
+
return np.nan
|
| 34 |
+
|
| 35 |
+
def _extract_atmosphere_from_op(op_conditions_dict, op_string, atm_config_local, entry_identifier):
|
| 36 |
+
atm_specific = atm_config_local["default_specific"]
|
| 37 |
+
atm_category = atm_config_local["default_category"]
|
| 38 |
+
found_atm = False
|
| 39 |
+
if isinstance(op_conditions_dict, dict):
|
| 40 |
+
atm_source_key_val = op_conditions_dict.get('atmosphere')
|
| 41 |
+
if not atm_source_key_val and 'text' in op_conditions_dict and isinstance(op_conditions_dict['text'], str) :
|
| 42 |
+
atm_source_key_val = op_conditions_dict['text']
|
| 43 |
+
if atm_source_key_val:
|
| 44 |
+
atm_str_to_parse = None
|
| 45 |
+
if isinstance(atm_source_key_val, list) and atm_source_key_val: atm_str_to_parse = str(atm_source_key_val[0])
|
| 46 |
+
elif isinstance(atm_source_key_val, str): atm_str_to_parse = atm_source_key_val
|
| 47 |
+
elif isinstance(atm_source_key_val, dict): atm_str_to_parse = str(atm_source_key_val.get('gas', atm_source_key_val.get('value', '')))
|
| 48 |
+
if atm_str_to_parse:
|
| 49 |
+
atm_str_lower = atm_str_to_parse.lower()
|
| 50 |
+
for pattern_regex, specific, category in atm_config_local["patterns"]:
|
| 51 |
+
if specific.lower() == atm_str_lower or re.search(pattern_regex, atm_str_to_parse, re.IGNORECASE):
|
| 52 |
+
atm_specific, atm_category, found_atm = specific, category, True; break
|
| 53 |
+
if not found_atm and '/' in atm_str_lower: atm_specific, atm_category, found_atm = atm_str_to_parse, "Mixed", True
|
| 54 |
+
if not found_atm and isinstance(op_string, str) and op_string:
|
| 55 |
+
for pattern_regex, specific, category in atm_config_local["patterns"]:
|
| 56 |
+
if re.search(pattern_regex, op_string, re.IGNORECASE):
|
| 57 |
+
atm_specific, atm_category, found_atm = specific, category, True; break
|
| 58 |
+
return atm_specific, atm_category
|
| 59 |
+
|
| 60 |
+
def _extract_mixing_method_from_op(op_dict, op_string, mix_config_local, entry_identifier):
    """Infer the mixing method of a single operation.

    Resolution order:
      1. regex patterns from ``mix_config_local["patterns"]`` against the
         free-text ``op_string``;
      2. the same patterns against the operation 'type', but only when that
         type mentions mix/grind/mill;
      3. the raw (lowercased) operation type itself, when it is a
         mixing-like type outside the generic labels;
      4. ``mix_config_local["default_method"]``.

    ``entry_identifier`` is accepted for parity with the sibling _extract_*
    helpers but is not used here.
    """
    mix_method = mix_config_local["default_method"]
    op_type = str(op_dict.get('type', '')).lower()
    # 1) a pattern match in the free-text description wins outright.
    if isinstance(op_string, str) and op_string:
        for pattern_regex, method_name in mix_config_local["patterns"]:
            if re.search(pattern_regex, op_string, re.IGNORECASE): return method_name
    # 2)/3) only consulted for mixing-like operation types.
    if 'mix' in op_type or 'grind' in op_type or 'mill' in op_type:
        for pattern_regex, method_name in mix_config_local["patterns"]:
            if re.search(pattern_regex, op_type, re.IGNORECASE): return method_name
        # NOTE(review): returns the raw type string as the "method" when it is
        # not one of the generic labels — assumes the type carries more
        # information than the default label; confirm against training code.
        if op_type.strip() and op_type not in ["mixing", "liquidgrinding", "solutionmixing", "grinding"]: return op_type
    return mix_method
|
| 71 |
+
|
| 72 |
+
def _extract_thermal_conditions(conditions_dict, op_string, entry_identifier):
    """Collect heating temperatures and durations from an operation's conditions.

    Reads the 'heating_temperature' and 'heating_time' entries of
    ``conditions_dict`` and coerces each to a float via
    ``_extract_numerical_value_from_op_condition``. ``op_string`` and
    ``entry_identifier`` are unused here, kept for signature parity with the
    sibling helpers.

    Returns:
        (temperatures, durations) — two lists, each containing at most one
        float; empty when the entry is missing or not coercible.
    """
    temperatures = []
    durations = []
    if not isinstance(conditions_dict, dict):
        return temperatures, durations

    # Each (source key, destination list) pair is processed identically.
    for source_key, sink in (('heating_temperature', temperatures),
                             ('heating_time', durations)):
        raw_entry = conditions_dict.get(source_key)
        if raw_entry:
            value = _extract_numerical_value_from_op_condition(raw_entry)
            if pd.notna(value):
                sink.append(value)

    return temperatures, durations
|
| 84 |
+
|
| 85 |
+
def parse_single_operation_detailed_for_input(op_dict_raw, entry_identifier="predict_op"):
    """Flatten one raw operation dict into per-operation feature values.

    Extracts thermal conditions, atmosphere, mixing method, and a set of
    boolean flags (heating / mixing / grinding / shaping / drying / quenching /
    annealing / sintering / calcination) derived from the operation 'type' and
    free-text 'string'. Returns {} for non-dict input.
    """
    if not isinstance(op_dict_raw, dict):
        return {}

    type_lc = str(op_dict_raw.get('type', 'UnknownType')).lower()
    string_lc = str(op_dict_raw.get('string', '')).lower()
    cond = op_dict_raw.get('conditions', {})

    # Delegate structured extraction to the dedicated helpers.
    temps, durs = _extract_thermal_conditions(cond, string_lc, entry_identifier)
    atm_specific, atm_category = _extract_atmosphere_from_op(cond, string_lc, ATMOSPHERE_CONFIG, entry_identifier)
    mix_method = _extract_mixing_method_from_op(op_dict_raw, string_lc, MIXING_METHOD_CONFIG, entry_identifier)

    default_mix = MIXING_METHOD_CONFIG["default_method"]
    grinding_methods = ['grinding', 'ball_milling', 'planetary_milling',
                        'attritor_milling', 'shaker_milling', 'mortar_pestle']

    return {
        'op_temp_C_list': temps,
        'op_duration_h_list': durs,
        'op_atmosphere_specific': atm_specific,
        'op_atmosphere_category': atm_category,
        'op_mixing_method': mix_method,
        # Keyword probes over the lowercased type/description strings.
        'op_is_heating': any(k in type_lc for k in ['heat', 'anneal', 'sinter', 'calcination']),
        'op_is_mixing': 'mix' in type_lc or mix_method != default_mix,
        'op_is_grinding': any(k in type_lc for k in ['grind', 'mill']) or 'pulverize' in string_lc or mix_method in grinding_methods,
        'op_is_shaping': 'shap' in type_lc,
        'op_is_drying': 'dry' in type_lc or 'drying' in type_lc,
        'op_is_quenching': 'quench' in type_lc,
        'op_is_annealing': 'anneal' in type_lc or 'anneal' in string_lc,
        'op_is_sintering': 'sinter' in type_lc or 'sinter' in string_lc,
        'op_is_calcination': any(k in type_lc for k in ['calcine', 'calcination']) or 'calcination' in string_lc,
    }
|
| 105 |
+
|
| 106 |
+
def generate_process_features_for_input(operations_simplified_list, all_possible_atm_categories, all_possible_mix_methods):
    """Aggregate per-operation features into one recipe-level feature dict.

    Args:
        operations_simplified_list: list of raw operation dicts (non-list
            input is treated as an empty recipe).
        all_possible_atm_categories: atmosphere category vocabulary used to
            pre-create the ops_atm_cat_* one-hot columns.
        all_possible_mix_methods: mixing method vocabulary used to pre-create
            the ops_mix_meth_* one-hot columns.

    Returns:
        dict of proc_* aggregates plus one-hot ops_atm_cat_* / ops_mix_meth_*
        columns (exactly one of each set can be 1).
    """
    # Defaults: counters at 0, flags False, continuous aggregates NaN.
    aggregated_ops_features = {
        'proc_total_heating_duration_h': 0.0, 'proc_max_temperature_C': np.nan,
        'proc_min_temperature_C': np.nan, 'proc_avg_temperature_C': np.nan,
        'proc_primary_heating_temp_C': np.nan,
        'proc_num_total_steps': 0, 'proc_num_heating_steps': 0,
        'proc_num_mixing_steps': 0, 'proc_num_grinding_steps': 0,
        'proc_has_annealing': False, 'proc_has_sintering': False,
        'proc_has_calcination': False, 'proc_has_quenching': False,
        'proc_has_shaping': False, 'proc_has_drying': False,
    }
    # Pre-create every one-hot column so downstream column alignment is stable.
    for cat in all_possible_atm_categories: aggregated_ops_features[f"ops_atm_cat_{cat}"] = 0
    for meth in all_possible_mix_methods: aggregated_ops_features[f"ops_mix_meth_{meth}"] = 0

    if not isinstance(operations_simplified_list, list): operations_simplified_list = []
    aggregated_ops_features['proc_num_total_steps'] = len(operations_simplified_list)
    all_temps_in_reaction, heating_steps_details_for_reaction, mixing_methods_found_in_reaction = [], [], []
    atm_set_for_reaction_flag = False
    parsed_atm_category_for_input = ATMOSPHERE_CONFIG["default_category"]
    parsed_mix_method_for_input = MIXING_METHOD_CONFIG["default_method"]

    for op_idx, op_dict_raw in enumerate(operations_simplified_list):
        op_features = parse_single_operation_detailed_for_input(op_dict_raw, f"predict_op_{op_idx}")
        if op_features.get('op_temp_C_list'): all_temps_in_reaction.extend(op_features['op_temp_C_list'])
        if op_features.get('op_is_heating'):
            aggregated_ops_features['proc_num_heating_steps'] += 1
            if op_features.get('op_duration_h_list'): aggregated_ops_features['proc_total_heating_duration_h'] += np.nansum(op_features['op_duration_h_list'])
            # Keep per-heating-step details so the "primary" (hottest/longest)
            # heating step can be selected after the loop.
            heating_steps_details_for_reaction.append({'temp': np.nanmax(op_features['op_temp_C_list']) if op_features.get('op_temp_C_list') and len(op_features['op_temp_C_list']) > 0 else np.nan,
                                                       'duration': np.nansum(op_features.get('op_duration_h_list', [0.0])),
                                                       'atm_category': op_features.get('op_atmosphere_category'),
                                                       'is_anneal': op_features.get('op_is_annealing'), 'is_sinter': op_features.get('op_is_sintering'), 'is_calcine': op_features.get('op_is_calcination')})
        if op_features.get('op_is_mixing'):
            aggregated_ops_features['proc_num_mixing_steps'] += 1
            current_mix_method = op_features.get('op_mixing_method', MIXING_METHOD_CONFIG["default_method"])
            if current_mix_method != MIXING_METHOD_CONFIG["default_method"]: mixing_methods_found_in_reaction.append(current_mix_method)
        if op_features.get('op_is_grinding'): aggregated_ops_features['proc_num_grinding_steps'] += 1
        if op_features.get('op_is_shaping'): aggregated_ops_features['proc_has_shaping'] = True
        if op_features.get('op_is_sintering'): aggregated_ops_features['proc_has_sintering'] = True
        if op_features.get('op_is_drying'): aggregated_ops_features['proc_has_drying'] = True
        if op_features.get('op_is_quenching'): aggregated_ops_features['proc_has_quenching'] = True
        if op_features.get('op_is_annealing'): aggregated_ops_features['proc_has_annealing'] = True
        if op_features.get('op_is_calcination'): aggregated_ops_features['proc_has_calcination'] = True
        # First operation with a non-default atmosphere wins for the recipe.
        if not atm_set_for_reaction_flag and op_features.get('op_atmosphere_category') != ATMOSPHERE_CONFIG["default_category"]:
            parsed_atm_category_for_input = op_features['op_atmosphere_category']
            atm_set_for_reaction_flag = True

    if heating_steps_details_for_reaction:
        # Primary heating step: highest temperature, ties broken by duration;
        # NaN temperatures sort last via the -inf substitute.
        primary_heat_step = max(heating_steps_details_for_reaction, key=lambda x: (x['temp'] if pd.notna(x['temp']) else -float('inf'), x['duration']))
        if pd.notna(primary_heat_step['temp']): aggregated_ops_features['proc_primary_heating_temp_C'] = primary_heat_step['temp']
        # Fall back to the primary heating step's atmosphere if no operation
        # set one earlier in the loop.
        if not atm_set_for_reaction_flag and primary_heat_step.get('atm_category') != ATMOSPHERE_CONFIG["default_category"]:
            parsed_atm_category_for_input = primary_heat_step['atm_category']

    if mixing_methods_found_in_reaction:
        # First explicit (non-default) mixing method found represents the recipe.
        parsed_mix_method_for_input = mixing_methods_found_in_reaction[0]

    # One-hot encode the chosen atmosphere category and mixing method.
    atm_ohe_col = f"ops_atm_cat_{parsed_atm_category_for_input}"
    if atm_ohe_col in aggregated_ops_features: aggregated_ops_features[atm_ohe_col] = 1

    mix_ohe_col = f"ops_mix_meth_{parsed_mix_method_for_input}"
    if mix_ohe_col in aggregated_ops_features: aggregated_ops_features[mix_ohe_col] = 1

    if all_temps_in_reaction:
        aggregated_ops_features['proc_max_temperature_C'] = np.nanmax(all_temps_in_reaction)
        aggregated_ops_features['proc_min_temperature_C'] = np.nanmin(all_temps_in_reaction)
        aggregated_ops_features['proc_avg_temperature_C'] = np.nanmean(all_temps_in_reaction)
    # A total duration of 0 (or no heating steps at all) is treated as
    # "unknown" rather than literally zero hours.
    if aggregated_ops_features['proc_num_heating_steps'] == 0 or pd.isna(aggregated_ops_features['proc_total_heating_duration_h']) or aggregated_ops_features['proc_total_heating_duration_h'] == 0:
        aggregated_ops_features['proc_total_heating_duration_h'] = np.nan

    return aggregated_ops_features
|
| 175 |
+
|
| 176 |
+
def generate_stoichiometry_features_for_input(reactants_simplified, products_simplified, standardize_fn_local):
    """Build stoichiometric coefficient features for reactants and products.

    Args:
        reactants_simplified: list of dicts like {'material': ..., 'amount': ...}
            (may be None/empty); only the first 3 contribute coefficients.
        products_simplified: same shape; only the first 2 contribute.
        standardize_fn_local: accepted for interface compatibility with the
            caller; not used in this implementation.

    Returns:
        dict with reactant{1..3}_coeff, product{1..2}_coeff (float or NaN) and
        num_reactants_in_reaction / num_products_in_reaction counts.
    """
    def _coeff(item):
        # Coerce an item's 'amount' to float; missing or malformed amounts
        # become NaN. (Fix: the original called float() unguarded, which
        # raised ValueError on non-numeric strings such as 'x' — pd.notna()
        # is True for any non-null string.)
        amount = item.get('amount')
        if pd.isna(amount):
            return np.nan
        try:
            return float(amount)
        except (TypeError, ValueError):
            return np.nan

    stoich_features = {}
    max_r, max_p = 3, 2
    # Coefficient slots default to NaN so the downstream imputer can fill them.
    for i in range(max_r):
        stoich_features[f'reactant{i+1}_coeff'] = np.nan
    for i in range(max_p):
        stoich_features[f'product{i+1}_coeff'] = np.nan

    stoich_features['num_reactants_in_reaction'] = len(reactants_simplified) if reactants_simplified else 0
    if reactants_simplified:
        for i, r_item in enumerate(reactants_simplified[:max_r]):
            if isinstance(r_item, dict):
                stoich_features[f'reactant{i+1}_coeff'] = _coeff(r_item)

    stoich_features['num_products_in_reaction'] = len(products_simplified) if products_simplified else 0
    if products_simplified:
        for i, p_item in enumerate(products_simplified[:max_p]):
            if isinstance(p_item, dict):
                stoich_features[f'product{i+1}_coeff'] = _coeff(p_item)

    return stoich_features
|