import os

# These must be exported before any TF/Keras- or CUDA-backed library is
# imported, which is why they sit above the remaining imports.
os.environ['TF_KERAS'] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import json
import math
import pickle
from typing import Any, Dict, List

import numpy as np
import pandas as pd
import torch
import deepchem as dc
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors
from tqdm import tqdm

from MY_GNN.inference import eval_on_host_train
from NIPS_GNN.inference import predict as nips_predict


def _as_float(value) -> float | None:
    """Best-effort conversion of a prediction to a built-in float.

    Accepts Python numbers, numpy scalars, and 0-d/1-element arrays and
    returns ``None`` when conversion is impossible.  The original
    ``isinstance(x, (float, int))`` checks silently rejected numpy scalars
    and arrays (e.g. ``np.float32`` or the arrays returned by
    ``model.predict``), turning valid predictions into ``None``.
    """
    try:
        arr = np.asarray(value, dtype=float).ravel()
        return float(arr[0]) if arr.size else None
    except (TypeError, ValueError):
        return None


def _fmt(value, spec: str, suffix: str = "") -> str:
    """Format a numeric prediction for display.

    Returns ``"N/A"`` when *value* is ``None`` — formatting ``None`` with a
    numeric spec previously raised ``TypeError`` inside ``predict()`` and
    collapsed a partially-successful result into an error dict.
    """
    return "N/A" if value is None else f"{value:{spec}}{suffix}"


class DeepChemTqdmCallback:
    """DeepChem-style training callback shown as a per-epoch tqdm bar.

    DeepChem invokes the callback as ``callback(model, current_step)`` after
    every batch, where ``current_step`` is the global batch count; this class
    maps that counter onto (epoch, batch-within-epoch) and renders one bar
    per epoch.
    """

    def __init__(self, dataset, batch_size, leave=False):
        """
        Parameters
        ----------
        dataset : object
            The training dataset; used only to infer the number of samples
            (via ``len()`` or, failing that, ``dataset.y.shape[0]``).
        batch_size : int
            Training batch size, used to derive steps-per-epoch.
        leave : bool
            Whether finished tqdm bars remain on screen.
        """
        self.dataset = dataset
        self.batch_size = int(batch_size)
        self.leave = leave
        # Try to infer the dataset length; fall back to the label array's
        # leading dimension, else give up and show an indeterminate bar.
        try:
            self.n = len(dataset)
        except Exception:
            y = getattr(dataset, "y", None)
            if hasattr(y, "shape"):
                self.n = int(y.shape[0])
            else:
                self.n = None
        self.steps = None if self.n is None else math.ceil(self.n / self.batch_size)
        self.pbar = None
        self.last_epoch = -1

    def __call__(self, model, current_step):
        """Advance the progress bar; called by DeepChem after each batch.

        ``current_step`` is the integer global batch count.
        """
        step = int(current_step)

        # Unknown steps-per-epoch: show an indeterminate spinner instead.
        if self.steps is None:
            if self.pbar is None:
                self.pbar = tqdm(total=None, desc=f"Step {step}", leave=self.leave)
            else:
                self.pbar.update(1)
            return

        epoch = step // self.steps
        batch_in_epoch = step % self.steps

        # Epoch rollover: close the previous bar and open a fresh one.
        if epoch != self.last_epoch:
            if self.pbar is not None:
                try:
                    self.pbar.close()
                except Exception:
                    pass
            self.pbar = tqdm(total=self.steps, desc=f"Epoch {epoch + 1}",
                             leave=self.leave)
            self.last_epoch = epoch
            # Jump the bar to the current batch (handles non-1 step jumps);
            # on the usual first call of an epoch batch_in_epoch == 0, so
            # this advances by 1.
            self.pbar.update(batch_in_epoch + 1)
            return

        # Same epoch: typical case, advance one batch.
        if self.pbar is not None:
            self.pbar.update(1)

    def close(self):
        """Close any open bar; call once after training finishes."""
        if self.pbar is not None:
            try:
                self.pbar.close()
            except Exception:
                pass
            self.pbar = None


class PolymerPropertyPredictor:
    """Predicts five polymer properties (Rg, Density, FFV, Tg, Tc) from a
    SMILES string using three pretrained model families (MY_GNN, NIPS_GNN,
    DA_GNN)."""

    def __init__(self):
        self.models = {}
        self.load_models()

    def load_models(self):
        """Record the on-disk model directory for each target property."""
        self.models = {
            'rg': "MY_GNN/trained_models/rg",
            'density': "MY_GNN/trained_models/density",
            'ffv': "NIPS_GNN/trained_models",
            'tg': "NIPS_GNN/trained_models",
            'tc': "DA_GNN/trained_models/2 layers",  # plain string: no placeholders
        }

    def predict_rg_density(self, SMILES: str) -> Dict[str, float]:
        """Predict Rg and Density using the MY_GNN fold-ensemble models.

        Returns a dict with keys ``'rg'`` and ``'density'``; each value is a
        float or ``None`` if the model output was not numeric.
        """
        output_df = pd.DataFrame(columns=["SMILES", "Density", "Rg"])
        output_df["SMILES"] = [SMILES]

        for label in ["Density", "Rg"]:
            host_csv = pd.DataFrame([SMILES], columns=['SMILES'])
            model_dir = self.models[label.lower()]
            preds, allp = eval_on_host_train(
                label,
                host_csv,
                model_pattern=f"{model_dir}/model_{label}_fold*.pt",
                desc_cols_file=f"{model_dir}/desc_cols_{label}.pkl",
                evaluate=False,
            )
            output_df[label] = preds

        return {
            'rg': _as_float(output_df["Rg"].values[0]),
            'density': _as_float(output_df["Density"].values[0]),
        }

    def predict_ffv_tg(self, SMILES: str) -> Dict[str, float]:
        """Predict FFV and Tg using the NIPS_GNN fingerprint models.

        Returns a dict with keys ``'ffv'`` and ``'tg'``; each value is a
        float or ``None`` if no prediction was produced.
        """
        pred_df = pd.DataFrame(columns=["SMILES", "FFV", "Tg"])
        pred_df["SMILES"] = [SMILES]

        for target in ['FFV', 'Tg']:
            model_dir = self.models[target.lower()]
            dict_path = f'{model_dir}/{target.lower()}_dictionaries.pkl'
            model_path = f'{model_dir}/{target.lower()}_model.pt'

            predictions = nips_predict([SMILES], target, model_path, dict_path)

            # Single input SMILES -> take the (possibly None) single prediction;
            # nips_predict wraps the scalar three levels deep.
            for pred in predictions:
                pred_df[target] = pred[0][0][0] if pred is not None else None

        return {
            'ffv': _as_float(pred_df["FFV"].values[0]),
            'tg': _as_float(pred_df["Tg"].values[0]),
        }

    def predict_tc(self, SMILES: str) -> float | None:
        """Predict Tc by averaging five DA_GNN GraphConv folds.

        BUG FIX: ``model.predict`` returns numpy arrays, so the original
        ``isinstance(val_pred[0], (float, int))`` guard was always False and
        this method unconditionally returned ``None``.
        """
        restore_model_dir = self.models['tc']
        smiles_df = pd.DataFrame([SMILES], columns=['SMILES'])

        # Featurization
        print("# Featurizerization -> ", end="")
        featurizer = dc.feat.ConvMolFeaturizer()
        molecules = [Chem.MolFromSmiles(s) for s in smiles_df['SMILES'].tolist()]
        featurized_mols = featurizer.featurize(molecules)
        testset = dc.data.NumpyDataset(X=featurized_mols)

        val_pred = []
        print("Predicting -> ", end="")
        for i in range(5):
            model_dir = restore_model_dir + '/' + 'loop' + str(i + 1)
            model = dc.models.GraphConvModel(1, mode="regression",
                                             model_dir=model_dir)
            model.restore()
            val_pred.append(model.predict(testset))
        print("Done")

        # Elementwise mean over the five fold predictions.
        mean_pred = sum(val_pred) / len(val_pred)
        return _as_float(mean_pred[0])

    def predict_all_properties(self, smiles: str) -> Dict[str, float]:
        """Run all three model families and combine their predictions.

        Returns a dict with keys ``'rg'``, ``'density'``, ``'ffv'``, ``'tg'``,
        ``'tc'``; raises ``ValueError`` (chained to the original error) when
        any sub-prediction fails.
        """
        try:
            rg_density = self.predict_rg_density(smiles)
            ffv_tg = self.predict_ffv_tg(smiles)
            tc = self.predict_tc(smiles)

            return {
                'rg': rg_density['rg'],
                'density': rg_density['density'],
                'ffv': ffv_tg['ffv'],
                'tg': ffv_tg['tg'],
                'tc': tc,
            }
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ValueError(f"Prediction Error: {str(e)}") from e


# Global predictor instance (lazily initialized by load_model()).
predictor = None


def load_model():
    """Load the model once (called when the container starts) and return it."""
    global predictor
    if predictor is None:
        predictor = PolymerPropertyPredictor()
    return predictor


def predict(inputs):
    """Main prediction entry point for HuggingFace.

    Accepts a raw SMILES string, a dict with an ``'inputs'``/``'smiles'``
    key, or a non-empty list of either.  Returns a result dict, or
    ``{"error": ...}`` on failure.
    """
    try:
        model = load_model()

        # Normalize the various accepted input formats to one SMILES string.
        if isinstance(inputs, str):
            smiles = inputs
        elif isinstance(inputs, dict):
            smiles = inputs.get('inputs', inputs.get('smiles', ''))
        elif isinstance(inputs, list) and len(inputs) > 0:
            smiles = inputs[0] if isinstance(inputs[0], str) else inputs[0].get('inputs', '')
        else:
            raise ValueError("Invalid input format")

        predictions = model.predict_all_properties(smiles)

        # _fmt guards None predictions with "N/A" instead of raising.
        result = {
            'smiles': smiles,
            'predictions': predictions,
            'properties': {
                'Tg (Glass Transition Temperature)': _fmt(predictions['tg'], ".2f", " °C"),
                'Tc (Crystallization Temperature)': _fmt(predictions['tc'], ".2f", " °C"),
                'FFV (Fractional Free Volume)': _fmt(predictions['ffv'], ".4f"),
                'Density': _fmt(predictions['density'], ".3f", " g/cm³"),
                'Rg (Radius of Gyration)': _fmt(predictions['rg'], ".2f", " Å"),
            }
        }
        return result
    except Exception as e:
        return {"error": str(e)}