# File size: 13,244 Bytes
# 3961ee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225

import pandas as pd
import numpy as np
import joblib
import os
import logging
from pymatgen.core import Composition 
import re 

from .constants import KNOWN_ELEMENT_SYMBOLS, ATMOSPHERE_CONFIG, MIXING_METHOD_CONFIG, MAGPIE_FEATURIZER, MAGPIE_LABELS, matminer_available
from .feature_engineering_utils import standardize_chemical_formula, generate_compositional_features
from .process_feature_utils import generate_process_features_for_input, generate_stoichiometry_features_for_input

# Artifact locations. load_all_artifacts_once() joins these onto
# os.path.dirname(__file__), so they are relative to this module's directory.
MODEL_DIR = "../models"
# Imputers, scalers and feature-column lists are read from this directory
# (currently the same directory as the models).
PREPROCESSOR_DIR = "../models"
ELEMENTAL_DATA_PATH = os.path.join(MODEL_DIR, "df_elements_processed.pkl")

# Module-level cache of everything load_all_artifacts_once() loads, keyed by
# artifact category, plus a "loaded_successfully" flag.
ESSENTIAL_OBJECTS: dict = {}
# Processed elemental-data table; stays None until the artifacts are loaded.
DF_ELEMENTS_PROCESSED_GLOBAL = None

def load_all_artifacts_once():
    """Load elemental data and per-model artifacts into the module-level caches.

    On the first call this reads the processed elemental-data pickle, tries to
    initialise the matminer Magpie featurizer if it is not flagged as
    available, and then loads the model, label encoder, imputer, scaler and
    feature-column list for the "temperature_bin" and "atmosphere_category"
    models. Everything is stored in ``ESSENTIAL_OBJECTS``; the elemental data
    is additionally bound to ``DF_ELEMENTS_PROCESSED_GLOBAL``.

    Returns:
        bool: True if the elemental data and every artifact of both models
        were loaded (or a previous call already succeeded). False if the
        elemental data could not be read (nothing else is attempted) or if
        any artifact of any model failed to load; in the latter case that
        model's entry in ``ESSENTIAL_OBJECTS["models"]`` is set to None.
    """
    global DF_ELEMENTS_PROCESSED_GLOBAL, ESSENTIAL_OBJECTS, matminer_available, MAGPIE_FEATURIZER, MAGPIE_LABELS
    if ESSENTIAL_OBJECTS.get("loaded_successfully"):
        logging.info("Artifacts already loaded.")
        return True

    logging.info("--- Loading Essential Artifacts for Prediction ---")
    script_dir = os.path.dirname(__file__)

    # Resolve the path before the try block so the error handler can always
    # reference it.
    elemental_data_full_path = os.path.join(script_dir, ELEMENTAL_DATA_PATH)
    try:
        # NOTE(review): unpickling executes arbitrary code; this file (and the
        # joblib artifacts below) must come from a trusted location.
        DF_ELEMENTS_PROCESSED_GLOBAL = pd.read_pickle(elemental_data_full_path)
        ESSENTIAL_OBJECTS["elemental_data"] = DF_ELEMENTS_PROCESSED_GLOBAL
        logging.info(f"Loaded processed elemental data from {elemental_data_full_path}")
    except Exception as e:
        logging.critical(f"CRITICAL: Error loading elemental data from {elemental_data_full_path}: {e}")
        return False

    if not matminer_available:  # Attempt to re-init if constants.py didn't catch it
        try:
            from matminer.featurizers.composition import ElementProperty
            featurizer = ElementProperty.from_preset("magpie", impute_nan=True)
            labels = [f'magpie_{label.replace(" ", "_")}' for label in featurizer.feature_labels()]
        except Exception as e:
            # Best effort only: a missing or broken matminer must not abort
            # loading, but Ctrl-C / SystemExit are no longer swallowed and the
            # reason is reported.
            logging.warning(f"Matminer could not be re-initialized in inference script: {e}")
        else:
            # Publish all three names together so they never end up half-set.
            # NOTE(review): this rebinds the names in this module only; modules
            # that imported them from constants keep their own bindings.
            MAGPIE_FEATURIZER = featurizer
            MAGPIE_LABELS = labels
            matminer_available = True
            logging.info("Matminer re-initialized in inference script.")

    # (cache category, directory relative to this module, file-name suffix),
    # listed in the order the artifacts are loaded.
    artifact_specs = (
        ("models", MODEL_DIR, "lgbm_model"),
        ("encoders", MODEL_DIR, "label_encoder"),
        ("imputers", PREPROCESSOR_DIR, "imputer"),
        ("scalers", PREPROCESSOR_DIR, "scaler"),
        ("feature_columns", PREPROCESSOR_DIR, "feature_columns"),
    )
    for category, _, _ in artifact_specs:
        ESSENTIAL_OBJECTS[category] = {}

    all_loaded_successfully = True
    for model_type_key in ["temperature_bin", "atmosphere_category"]:
        model_artifact_name = f"{model_type_key}_tuned"
        try:
            for category, relative_dir, suffix in artifact_specs:
                artifact_path = os.path.join(script_dir, relative_dir, f"{model_artifact_name}_{suffix}.joblib")
                ESSENTIAL_OBJECTS[category][model_type_key] = joblib.load(artifact_path)
            logging.info(f"Loaded artifacts for {model_artifact_name} model.")
        except Exception as e:
            logging.error(f"Error loading one or more artifacts for '{model_artifact_name}': {e}. Predictions for it may fail.")
            # Mark the model unavailable; artifacts loaded before the failure
            # stay in their caches.
            ESSENTIAL_OBJECTS["models"][model_type_key] = None
            all_loaded_successfully = False

    ESSENTIAL_OBJECTS["loaded_successfully"] = all_loaded_successfully
    return all_loaded_successfully

def create_feature_vector_for_prediction(raw_synthesis_input, model_target_name):
    """Build the single-row feature frame used to predict one target.

    The row is assembled from target compositional features, aggregated
    precursor compositional features, process features and stoichiometry
    features, restricted to the feature columns stored for
    ``model_target_name``. Selected numeric columns are then passed through
    that model's imputer and scaler when both are available.

    Args:
        raw_synthesis_input (dict): Raw synthesis description. Keys read here:
            'target_formula_raw', 'precursor_formulas_raw',
            'operations_simplified_list', 'reactants_simplified' and
            'products_simplified'. Absent keys fall back to None (target) or
            an empty list (the others); an explicit None for
            'precursor_formulas_raw' is treated as an empty list.
        model_target_name (str): Key of the model in ``ESSENTIAL_OBJECTS``
            whose feature columns, imputer and scaler are used.

    Returns:
        pandas.DataFrame | None: One row with exactly the expected feature
        columns, or None if the elemental data is not loaded, no feature
        column list exists for ``model_target_name``, or imputation/scaling
        raised.
    """
    if DF_ELEMENTS_PROCESSED_GLOBAL is None:
        logging.error("Elemental data not loaded. Call load_all_artifacts_once() first.")
        return None

    expected_feature_cols = ESSENTIAL_OBJECTS.get("feature_columns", {}).get(model_target_name)
    # Explicit None/len check: truthiness of an array-like column list
    # (numpy array, pandas Index) raises instead of answering "is it empty?".
    if expected_feature_cols is None or len(expected_feature_cols) == 0:
        logging.error(f"Feature column list for '{model_target_name}' not found in loaded artifacts.")
        return None
    expected_feature_cols = list(expected_feature_cols)

    # Flag/one-hot style columns default to 0; everything else starts as NaN
    # so the imputer can fill whatever is not computed below.
    zero_default_prefixes = ("ops_", "proc_has_", "elem_block_")
    feature_dict = {}
    for col in expected_feature_cols:
        defaults_to_zero = (
            col.startswith(zero_default_prefixes)
            or "is_stoichiometric" in col
            or "is_elements_only" in col
        )
        feature_dict[col] = 0 if defaults_to_zero else np.nan

    # Target Compositional Features
    std_target_output = standardize_chemical_formula(raw_synthesis_input.get('target_formula_raw'), "predict_target")
    target_comp_feats = generate_compositional_features(std_target_output, DF_ELEMENTS_PROCESSED_GLOBAL, "predict_target_comp")
    for k, v in target_comp_feats.items():
        feature_key = f'target_{k}'
        if feature_key in feature_dict:
            feature_dict[feature_key] = v

    # Precursor Compositional Features
    precursor_formulas_raw = raw_synthesis_input.get('precursor_formulas_raw')
    if precursor_formulas_raw is None:
        # Key present but explicitly None: treat like "no precursors" rather
        # than crashing in enumerate().
        precursor_formulas_raw = []
    std_precursors_outputs = [
        standardize_chemical_formula(p, f"predict_prec_{i}")
        for i, p in enumerate(precursor_formulas_raw)
    ]
    num_valid_precursors = 0
    num_stoich_precursors = 0
    num_elements_only_precursors = 0
    precursor_comp_feats_list = []
    for std_p_output in std_precursors_outputs:
        if std_p_output is not None:
            num_valid_precursors += 1
            # A plain string result is counted as stoichiometric; a dict
            # tagged 'elements_only' is counted separately.
            if isinstance(std_p_output, str):
                num_stoich_precursors += 1
            elif isinstance(std_p_output, dict) and std_p_output.get('type') == 'elements_only':
                num_elements_only_precursors += 1
        # Features are generated for every precursor, including ones that
        # failed standardization (None).
        precursor_comp_feats_list.append(
            generate_compositional_features(std_p_output, DF_ELEMENTS_PROCESSED_GLOBAL, "predict_prec_comp")
        )

    has_valid_precursors = num_valid_precursors > 0
    feature_dict['num_valid_precursors'] = num_valid_precursors
    feature_dict['all_prec_are_stoichiometric'] = (num_stoich_precursors == num_valid_precursors) if has_valid_precursors else False
    feature_dict['any_prec_is_elements_only'] = (num_elements_only_precursors > 0) if has_valid_precursors else False

    if precursor_comp_feats_list:
        df_prec_feats = pd.DataFrame(precursor_comp_feats_list)
        numeric_cols_df_prec = df_prec_feats.select_dtypes(include=np.number)
        if not numeric_cols_df_prec.empty:
            # A sample formula is featurized only to discover which
            # compositional feature names are numeric.
            temp_sample_df = pd.DataFrame([generate_compositional_features("H2O", DF_ELEMENTS_PROCESSED_GLOBAL)])
            numeric_sample_comp_keys = [
                k for k in temp_sample_df.columns
                if pd.api.types.is_numeric_dtype(temp_sample_df[k]) and k not in ['is_stoichiometric_formula']
            ]
            for agg_func_name in ['mean', 'std', 'min', 'max', 'sum']:
                aggregated_vals = getattr(numeric_cols_df_prec, agg_func_name)()
                for feat_name_suffix in numeric_sample_comp_keys:
                    agg_feat_key = f"{agg_func_name}_prec_{feat_name_suffix}"
                    if agg_feat_key in feature_dict and feat_name_suffix in aggregated_vals:
                        feature_dict[agg_feat_key] = aggregated_vals[feat_name_suffix]

    # Process Features
    process_input_ops_list = raw_synthesis_input.get('operations_simplified_list', [])
    # The category vocabularies are recovered from the one-hot column names
    # the model was trained with.
    all_atm_cats = list(set([col.split('ops_atm_cat_')[-1] for col in expected_feature_cols if col.startswith('ops_atm_cat_')]))
    all_mix_meths = list(set([col.split('ops_mix_meth_')[-1] for col in expected_feature_cols if col.startswith('ops_mix_meth_')]))
    proc_feats_generated = generate_process_features_for_input(process_input_ops_list, all_atm_cats, all_mix_meths)
    for k, v in proc_feats_generated.items():
        if k in feature_dict:
            feature_dict[k] = v

    # Stoichiometry features
    reactants_simplified = raw_synthesis_input.get('reactants_simplified', [])
    products_simplified = raw_synthesis_input.get('products_simplified', [])
    stoich_feats_generated = generate_stoichiometry_features_for_input(reactants_simplified, products_simplified, standardize_chemical_formula)
    for k, v in stoich_feats_generated.items():
        if k in feature_dict:
            feature_dict[k] = v

    # columns= fixes the column order and drops any key that is not an
    # expected feature.
    feature_vector_df = pd.DataFrame([feature_dict], columns=expected_feature_cols)

    # Impute and Scale
    imputer = ESSENTIAL_OBJECTS.get("imputers", {}).get(model_target_name)
    scaler = ESSENTIAL_OBJECTS.get("scalers", {}).get(model_target_name)

    # Flag/one-hot columns and the precursor count are excluded from
    # imputation and scaling.
    untransformed_prefixes = ('ops_', 'proc_has_', 'elem_block_')
    untransformed_names = ['is_stoichiometric_formula', 'all_prec_are_stoichiometric', 'any_prec_is_elements_only', 'num_valid_precursors']
    numerical_features_for_transform = [
        col for col in expected_feature_cols
        if col in feature_vector_df.columns
        and pd.api.types.is_numeric_dtype(feature_vector_df[col].dtype)
        and not col.startswith(untransformed_prefixes)
        and col not in untransformed_names
    ]

    if imputer is not None and scaler is not None and numerical_features_for_transform:
        try:
            feature_vector_df[numerical_features_for_transform] = feature_vector_df[numerical_features_for_transform].astype(np.float64)
            feature_vector_df[numerical_features_for_transform] = imputer.transform(feature_vector_df[numerical_features_for_transform])
            feature_vector_df[numerical_features_for_transform] = scaler.transform(feature_vector_df[numerical_features_for_transform])
            logging.info("Feature vector imputed and scaled for prediction.")
        except Exception as e_transform:
            logging.error(f"Error during imputation/scaling for prediction: {e_transform}", exc_info=True)
            return None
    else:
        logging.warning("Imputer, Scaler or numerical features missing for prediction. Proceeding with caution.")
    return feature_vector_df


def predict_synthesis_outcome(raw_synthesis_input):
    """Run every available model on one raw synthesis description.

    Artifacts are loaded on demand when they have not been loaded
    successfully before.

    Args:
        raw_synthesis_input (dict): Raw synthesis description, forwarded
            unchanged to create_feature_vector_for_prediction().

    Returns:
        dict: Empty when the artifacts cannot be loaded. Otherwise keyed by
        model type ("temperature_bin", "atmosphere_category"); each value is
        either a dict with 'predicted_label' and 'probabilities' (class name
        -> probability), or an error string when feature creation or the
        prediction itself failed. Model types whose model is unavailable are
        omitted.
    """
    global ESSENTIAL_OBJECTS
    already_loaded = ESSENTIAL_OBJECTS.get("loaded_successfully")
    if not already_loaded and not load_all_artifacts_once():
        logging.error("Essential artifacts could not be loaded. Cannot make predictions.")
        return {}

    predictions = {}
    for model_type in ("temperature_bin", "atmosphere_category"):
        model = ESSENTIAL_OBJECTS["models"].get(model_type)
        if not model:
            logging.warning(f"{model_type} model not available for prediction.")
            continue

        logging.info(f"\n--- Predicting {model_type} ---")
        feature_vector = create_feature_vector_for_prediction(raw_synthesis_input, model_type)
        if feature_vector is None:
            logging.error(f"Could not create feature vector for {model_type} model.")
            predictions[model_type] = "Feature vector creation error"
            continue

        encoder = ESSENTIAL_OBJECTS["encoders"][model_type]
        try:
            encoded_prediction = model.predict(feature_vector)
            class_probabilities = model.predict_proba(feature_vector)
            pred_label = encoder.inverse_transform(encoded_prediction)[0]

            # Pair each encoder class (as a string) with its probability for
            # the single input row.
            probability_by_class = dict(
                zip((str(cls) for cls in encoder.classes_), class_probabilities[0])
            )
            predictions[model_type] = {
                'predicted_label': pred_label,
                'probabilities': probability_by_class,
            }
            logging.info(f"Predicted {model_type}: {pred_label}")
            logging.info(f"Probabilities: {probability_by_class}")
        except Exception as e:
            logging.error(f"Error during {model_type} prediction: {e}", exc_info=True)
            predictions[model_type] = f"Prediction Error: {e}"

    return predictions

if __name__ == '__main__':
    # Manual smoke test: load the artifacts, then predict for one hard-coded
    # example synthesis and print the result.
    # NOTE(review): the module uses relative imports, so this presumably has
    # to be started as a module of its package (python -m ...) - confirm.
    artifacts_ready = load_all_artifacts_once()

    if artifacts_ready:
        print("\n--- Example Interactive Prediction ---")

        mixing_step = {
            'type': 'MixingOperation',
            'string': 'Mix precursors by ball milling for 4h',
            'conditions': {'duration': [{'value': 4, 'unit': 'h'}]},
        }
        calcination_step = {
            'type': 'HeatingOperation',
            'string': 'Calcined at 900C for 12h in air',
            'conditions': {
                'heating_temperature': [{'value': 900, 'unit': 'C'}],
                'heating_time': [{'value': 12, 'unit': 'h'}],
                'atmosphere': 'Air',
            },
        }
        # NOTE(review): the description says "24h" but heating_time is 20;
        # one of the two is presumably a typo - confirm which is intended.
        sintering_step = {
            'type': 'HeatingOperation',
            'string': 'Sintered at 950C for 24h in O2',
            'conditions': {
                'heating_temperature': [{'value': 950, 'unit': 'C'}],
                'heating_time': [{'value': 20, 'unit': 'h'}],
                'atmosphere': 'Oxygen',
            },
        }

        example_input_with_ops_list = {
            'target_formula_raw': "YBa2Cu3O7",
            'precursor_formulas_raw': ["Y2O3", "BaCO3", "CuO"],
            'operations_simplified_list': [mixing_step, calcination_step, sintering_step],
            'reactants_simplified': [
                {'material': 'Y2O3', 'amount': 0.5},
                {'material': 'BaCO3', 'amount': 2.0},
                {'material': 'CuO', 'amount': 3.0},
            ],
            'products_simplified': [{'material': 'YBa2Cu3O7', 'amount': 1.0}],
        }

        predictions = predict_synthesis_outcome(example_input_with_ops_list)
        print(f"\nFinal Predictions for example input: {predictions}")
    else:
        print("Exiting due to failure in loading essential artifacts.")