# chemberta-3-phinformed / test_model.py
import argparse
import os
import sys
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import cross_validate
# Default configuration: MoleculeNet benchmark tasks and the pretrained backbone.
DEFAULT_TASKS = ['ESOL', 'FreeSolv', 'HIV', 'BACE', 'BBBP', 'ClinTox']
MODEL_NAME = "DeepChem/ChemBERTa-10M-MLM"
def load_model_and_checkpoint(checkpoint_path, device="cpu"):
    """Load the ChemBERTa backbone and overwrite its weights from a checkpoint.

    Args:
        checkpoint_path: Path to a ``.pth`` file containing a ``'backbone'``
            entry with a RoBERTa state dict (saved via ``torch.save``).
        device: Device string ("cpu" or "cuda") the model is moved to.

    Returns:
        A ``(tokenizer, backbone)`` pair; the backbone is in eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    # Fail fast: validate the checkpoint path BEFORE downloading/loading the
    # (potentially large) pretrained model from the Hugging Face hub.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"File not found: {checkpoint_path}")
    print(f"Loading model {MODEL_NAME}...", file=sys.stderr)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    backbone = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).roberta.to(device)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    backbone.load_state_dict(checkpoint['backbone'])
    backbone.eval()
    print("Model is loaded", file=sys.stderr)
    return tokenizer, backbone
@torch.no_grad()
def mol_to_emb(smiles, tokenizer, model, device="cpu"):
    """Embed a single SMILES string with the frozen backbone.

    The embedding concatenates the first-token ([CLS]) hidden state with the
    mean of all remaining token states, yielding a 1-D numpy vector of twice
    the model's hidden size.
    """
    encoded = tokenizer([smiles], padding=False, return_tensors="pt")
    token_ids = encoded['input_ids'].to(device)
    hidden = model(token_ids).last_hidden_state
    cls_vec = hidden[:, 0]
    mean_vec = hidden[:, 1:].mean(dim=1)
    combined = torch.cat([cls_vec, mean_vec], dim=1)
    return combined.squeeze(0).cpu().numpy()
def evaluate_tasks(checkpoint_path, data_dir='./support/', device="cpu"):
    """Evaluate checkpoint embeddings on the benchmark tasks via 5-fold CV.

    For every task in ``DEFAULT_TASKS`` whose CSV exists under *data_dir*,
    embeds each SMILES string (column ``X``) with the loaded backbone, then
    cross-validates a random forest on the frozen embeddings against column
    ``y``. ESOL/FreeSolv are scored with MAE, the rest with F1-macro.

    Returns:
        dict mapping task name -> (mean_score, std_score).
    """
    tokenizer, model = load_model_and_checkpoint(checkpoint_path, device)

    def _estimator_and_metric(task_name):
        # Regression for the solubility tasks; binary classification otherwise.
        if task_name in ('ESOL', 'FreeSolv'):
            return (RandomForestRegressor(random_state=42, n_jobs=5),
                    'neg_mean_absolute_error', "MAE")
        return (RandomForestClassifier(random_state=42, n_jobs=5),
                'f1_macro', "F1-macro")

    results = {}
    for task in DEFAULT_TASKS:
        csv_path = os.path.join(data_dir, f"{task}.csv")
        if not os.path.exists(csv_path):
            print(f"\n[WARN] File {csv_path} not found. Skipping '{task}'.", file=sys.stderr)
            continue
        print(f"Task: {task}", file=sys.stderr)
        df = pd.read_csv(csv_path, sep='\t')
        # Embed every molecule with the frozen backbone.
        df['v'] = df['X'].apply(lambda smi: mol_to_emb(smi, tokenizer, model, device))
        # Deterministic shuffle so CV folds are reproducible.
        df = df.sample(frac=1, random_state=42).reset_index(drop=True)
        features = np.stack(df['v'].values)
        targets = df['y'].to_numpy()
        estimator, scoring, metric_name = _estimator_and_metric(task)
        cv_results = cross_validate(estimator, features, targets, cv=5, scoring=scoring, n_jobs=1)
        scores = cv_results['test_score']
        results[task] = (scores.mean(), scores.std())
        print(f" {metric_name}: {scores.mean():.4f} ± {scores.std():.4f}", file=sys.stderr)
    for task, (mean, std) in results.items():
        print(f"{task:10}: {mean:.4f} ± {std:.4f}")
    return results
if __name__ == "__main__":
    # CLI entry point: positional checkpoint path, optional target device.
    cli = argparse.ArgumentParser()
    cli.add_argument("checkpoint_path", type=str, help="Path to checkpoint file (.pth)")
    cli.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda"])
    ns = cli.parse_args()
    evaluate_tasks(checkpoint_path=ns.checkpoint_path, device=ns.device)