| import argparse |
| import os |
| import sys |
| import numpy as np |
| import pandas as pd |
| import torch |
| from transformers import AutoTokenizer, AutoModelForMaskedLM |
| from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier |
| from sklearn.model_selection import cross_validate |
|
|
| |
# Benchmark datasets: ESOL and FreeSolv are treated as regression tasks,
# the remaining four as binary classification (see evaluate_tasks).
DEFAULT_TASKS = ['ESOL', 'FreeSolv', 'HIV', 'BACE', 'BBBP', 'ClinTox']
# HuggingFace Hub id of the ChemBERTa masked-LM backbone used for embeddings.
MODEL_NAME = "DeepChem/ChemBERTa-10M-MLM"
|
|
def load_model_and_checkpoint(checkpoint_path, device="cpu"):
    """Load the ChemBERTa tokenizer/backbone and restore fine-tuned weights.

    Parameters
    ----------
    checkpoint_path : str
        Path to a torch checkpoint containing a ``'backbone'`` state dict.
    device : str
        Device to place the encoder on ("cpu" or "cuda").

    Returns
    -------
    tuple
        ``(tokenizer, backbone)`` — the HF tokenizer and the RoBERTa
        encoder in eval mode.

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist.
    """
    # Fail fast: verify the checkpoint exists BEFORE the (potentially large
    # and slow) model download from the HuggingFace hub.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"File not found: {checkpoint_path}")

    print(f"Loading model {MODEL_NAME}...", file=sys.stderr)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Keep only the RoBERTa encoder; the MLM head is not needed for embeddings.
    backbone = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).roberta.to(device)

    # NOTE(security): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    backbone.load_state_dict(checkpoint['backbone'])
    backbone.eval()
    print("Model is loaded", file=sys.stderr)
    return tokenizer, backbone
|
|
@torch.no_grad()
def mol_to_emb(smiles, tokenizer, model, device="cpu"):
    """Embed a single SMILES string with the encoder.

    The embedding concatenates the first token's hidden state (CLS-style)
    with the mean of all remaining token hidden states, yielding a 1-D
    numpy array of length ``2 * hidden_size``.
    """
    encoded = tokenizer([smiles], padding=False, return_tensors="pt")
    ids = encoded['input_ids'].to(device)
    hidden = model(ids).last_hidden_state  # (1, seq_len, hidden_size)
    cls_vec = hidden[:, 0]
    tail_mean = hidden[:, 1:].mean(dim=1)
    features = torch.cat([cls_vec, tail_mean], dim=1)
    return features.squeeze(0).cpu().numpy()
|
|
def evaluate_tasks(checkpoint_path, data_dir='./support/', device="cpu"):
    """Evaluate checkpointed embeddings on the benchmark tasks with 5-fold CV.

    Each ``<task>.csv`` in *data_dir* is tab-separated with columns
    ``'X'`` (SMILES string) and ``'y'`` (target). Molecules are embedded
    with the fine-tuned backbone, then a random forest is cross-validated:
    regression (MAE) for ESOL/FreeSolv, classification (F1-macro) otherwise.

    Parameters
    ----------
    checkpoint_path : str
        Path to the fine-tuned backbone checkpoint.
    data_dir : str
        Directory containing the per-task CSV files; missing files are
        skipped with a warning.
    device : str
        Device for the embedding model.

    Returns
    -------
    dict
        Mapping ``task -> (mean_score, std_score)``. MAE is returned as a
        positive error (lower is better); F1-macro is higher-is-better.
    """
    tasks = DEFAULT_TASKS
    tokenizer, model = load_model_and_checkpoint(checkpoint_path, device)

    results = {}
    for task in tasks:
        csv_path = os.path.join(data_dir, f"{task}.csv")
        if not os.path.exists(csv_path):
            print(f"\n[WARN] File {csv_path} not found. Skipping '{task}'.", file=sys.stderr)
            continue

        print(f"Task: {task}", file=sys.stderr)
        ds = pd.read_csv(csv_path, sep='\t')

        # Embed every molecule, then shuffle rows (fixed seed) so the CV
        # folds are not affected by any ordering in the source file.
        ds['v'] = ds['X'].apply(lambda x: mol_to_emb(x, tokenizer, model, device))
        ds = ds.sample(frac=1, random_state=42).reset_index(drop=True)

        X = np.stack(ds['v'].values)
        y = ds['y'].to_numpy()

        if task in ['ESOL', 'FreeSolv']:
            rf_model = RandomForestRegressor(random_state=42, n_jobs=5)
            scoring = 'neg_mean_absolute_error'
            metric_name = "MAE"
        else:
            rf_model = RandomForestClassifier(random_state=42, n_jobs=5)
            scoring = 'f1_macro'
            metric_name = "F1-macro"

        cv_results = cross_validate(rf_model, X, y, cv=5, scoring=scoring, n_jobs=1)
        scores = cv_results['test_score']
        # BUG FIX: sklearn's 'neg_mean_absolute_error' yields NEGATED MAE
        # (greater-is-better convention); flip the sign so the value we
        # report and return is a positive error, matching the "MAE" label.
        if scoring == 'neg_mean_absolute_error':
            scores = -scores
        mean_score = scores.mean()
        std_score = scores.std()
        results[task] = (mean_score, std_score)
        print(f" {metric_name}: {mean_score:.4f} ± {std_score:.4f}", file=sys.stderr)

    # Final summary goes to stdout (per-task progress went to stderr).
    for task, (mean, std) in results.items():
        print(f"{task:10}: {mean:.4f} ± {std:.4f}")
    return results
|
|
if __name__ == "__main__":
    # CLI entry point: evaluate one checkpoint on every benchmark task.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "checkpoint_path", type=str,
        help="Path to checkpoint file (.pth)"
    )
    cli.add_argument(
        "--device", type=str, default="cpu", choices=["cpu", "cuda"],
    )
    parsed = cli.parse_args()

    evaluate_tasks(
        checkpoint_path=parsed.checkpoint_path,
        device=parsed.device,
    )
|
|