Spaces:
Sleeping
Sleeping
Mdasif45 committed on
Commit ·
e620469
1
Parent(s): 316e7a7
Initial FastAPI deployment
Browse files- Dockerfile +13 -0
- main.py +105 -0
- predict_all.py +25 -0
- predict_band_gap.py +185 -0
- predict_e_above_hull.py +185 -0
- predict_epa.py +184 -0
- predict_fepa.py +185 -0
- predict_is_gap_direct.py +114 -0
- predict_volume.py +185 -0
- requirements.txt +13 -0
Dockerfile
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim

# Hugging Face Spaces documents that Docker containers run as UID 1000.
# Create that user up front so the app can write downloaded checkpoints and
# the Hub cache at startup instead of failing on root-owned directories.
RUN useradd --create-home --uid 1000 user

WORKDIR /app

# Dependencies are installed as root into the system site-packages so the
# `uvicorn` entry point stays on the default PATH.
COPY requirements.txt .

RUN pip install --no-cache-dir -r requirements.txt

COPY --chown=user:user . .

# The application creates files under its working directory at startup.
RUN chown user:user /app

USER user
ENV HOME=/home/user

EXPOSE 7860

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
main.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Callable, Dict, Optional
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, HTTPException
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
|
| 11 |
+
# Threshold forwarded to the predictor on every /predict call.
THRESHOLD = 0.33
# Set by the startup hook; stays None until the predictor has been imported.
predict_all_fn: Optional[Callable[..., Dict[str, object]]] = None


def _locate_project_dir() -> Path:
    """Return the directory that checkpoints must be placed under.

    The predictor modules pick the first of "this directory", "its parent" and
    "<parent>/LLM-Prop" that contains ``llmprop_model.py`` and read their
    checkpoints relative to it. The same search is used here so downloaded
    files land where the predictors look for them. When no candidate matches,
    the parent directory is returned, which was the previous fixed value.
    """
    script_dir = Path(__file__).resolve().parent
    candidates = (script_dir, script_dir.parent, script_dir.parent / "LLM-Prop")
    for candidate in candidates:
        if (candidate / "llmprop_model.py").exists():
            return candidate
    return script_dir.parent


PROJECT_DIR = _locate_project_dir()
HF_REPO_ID = "asif45/LLM-PROP"
# Optional access token; None when the HF_TOKEN environment variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Every checkpoint uses the same relative path locally and inside the Hub repo.
_CHECKPOINT_PATHS = (
    "checkpoints/samples/classification/best_checkpoint_for_is_gap_direct.pt",
    "checkpoints/samples/regression/best_checkpoint_for_band_gap.pt",
    "checkpoints/samples/regression/best_checkpoint_for_energy_per_atom.pt",
    "checkpoints/samples/regression/best_checkpoint_for_e_above_hull.pt",
    "checkpoints/samples/regression/best_checkpoint_for_fepa.pt",
    "checkpoints/samples/regression/best_checkpoint_for_volume.pt",
)
# Maps local path (relative to PROJECT_DIR) -> file path inside the Hub repo.
CHECKPOINT_FILES = {path: path for path in _CHECKPOINT_PATHS}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PredictRequest(BaseModel):
    """Request body accepted by ``POST /predict``."""

    # Required field (no default): the crystal description to run through the models.
    text: str = Field(..., description="Crystal description text")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PredictResponse(BaseModel):
    """Response body returned by ``POST /predict``: one entry per predicted property."""

    # Classifier output, exposed as a string.
    is_gap_direct: str
    # Remaining fields are numeric predictions, one per property.
    energy_per_atom: float
    formation_energy_per_atom: float
    band_gap: float
    e_above_hull: float
    volume: float
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
app = FastAPI(title="Crystal Property Predictor API")

# Browser origins allowed to call the API: localhost / 127.0.0.1 on ports 8080 and 5173.
_ALLOWED_ORIGINS = [
    "http://localhost:8080",
    "http://127.0.0.1:8080",
    "http://localhost:5173",
    "http://127.0.0.1:5173",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def ensure_checkpoint_files() -> None:
    """Download every checkpoint in ``CHECKPOINT_FILES`` that is missing locally.

    Each entry maps a path relative to ``PROJECT_DIR`` to a file inside the
    ``HF_REPO_ID`` model repository. Files that already exist are left alone.

    The download is first copied to a temporary sibling file and then moved
    into place with ``os.replace``. An interrupted copy therefore never leaves
    a truncated checkpoint that a later run would mistake for a complete one.
    """
    for local_relative_path, repo_file_path in CHECKPOINT_FILES.items():
        local_path = PROJECT_DIR / local_relative_path
        if local_path.exists():
            continue

        local_path.parent.mkdir(parents=True, exist_ok=True)
        downloaded_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=repo_file_path,
            repo_type="model",
            token=HF_TOKEN,
        )

        partial_path = local_path.with_name(local_path.name + ".part")
        try:
            shutil.copy2(downloaded_path, partial_path)
            os.replace(partial_path, local_path)
        finally:
            # No-op after a successful replace; removes the leftover after a failed copy.
            partial_path.unlink(missing_ok=True)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@app.on_event("startup")
def load_model_once() -> None:
    """Startup hook: fetch missing checkpoints, then bind the predictor entry point.

    The import of ``predict_all`` is deliberately deferred until after the
    download step so that the predictor modules find the local checkpoint
    files when they are imported.
    """
    global predict_all_fn

    ensure_checkpoint_files()

    from predict_all import predict_all as imported_predict_all

    predict_all_fn = imported_predict_all
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
@app.get("/health")
def health() -> Dict[str, object]:
    """Report liveness and whether the startup hook has bound the predictor."""
    model_ready = predict_all_fn is not None
    return {"status": "ok", "model_loaded": model_ready}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Run all property predictors on the submitted crystal description.

    Raises:
        HTTPException: 503 while the predictor has not been loaded yet;
            400 when the text is empty after stripping whitespace.
    """
    if predict_all_fn is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    description = payload.text.strip()
    if description == "":
        raise HTTPException(status_code=400, detail="Text input cannot be empty")

    raw_predictions = predict_all_fn(description, threshold=THRESHOLD)

    # Only these keys are returned; any other key in the predictor output is dropped.
    response_fields = (
        "is_gap_direct",
        "energy_per_atom",
        "formation_energy_per_atom",
        "band_gap",
        "e_above_hull",
        "volume",
    )
    return PredictResponse(**{name: raw_predictions[name] for name in response_fields})
|
predict_all.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from predict_is_gap_direct import predict
|
| 2 |
+
from predict_epa import predict_epa
|
| 3 |
+
from predict_fepa import predict_fepa
|
| 4 |
+
from predict_band_gap import predict_band_gap
|
| 5 |
+
from predict_e_above_hull import predict_e_above_hull
|
| 6 |
+
from predict_volume import predict_volume
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def predict_all(text, threshold=0.33, max_length=256):
    """Run the classifier and every regression predictor on ``text``.

    Args:
        text: Crystal description passed unchanged to each predictor.
        threshold: Forwarded to the ``is_gap_direct`` classifier only.
        max_length: Forwarded to the regression predictors only.

    Returns:
        Dict with ``is_gap_direct`` and ``confidence`` from the classifier plus
        one entry per regression property.
    """
    gap_label, gap_confidence = predict(text, threshold=threshold)
    results = {"is_gap_direct": gap_label, "confidence": gap_confidence}

    # Order matches the key order of the returned dict.
    regressors = (
        ("energy_per_atom", predict_epa),
        ("formation_energy_per_atom", predict_fepa),
        ("band_gap", predict_band_gap),
        ("e_above_hull", predict_e_above_hull),
        ("volume", predict_volume),
    )
    for property_key, regressor in regressors:
        results[property_key] = regressor(text, max_length=max_length)

    return results
|
predict_band_gap.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import contextlib
|
| 3 |
+
import io
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
| 11 |
+
from transformers.utils import logging as transformers_logging
|
| 12 |
+
|
| 13 |
+
# Directory holding this script; the project root is searched for relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PARENT_DIR = os.path.dirname(SCRIPT_DIR)
PROJECT_CANDIDATES = [
    SCRIPT_DIR,
    _PARENT_DIR,
    os.path.join(_PARENT_DIR, "LLM-Prop"),
]

# The first candidate containing llmprop_model.py wins; None when none does.
PROJECT_DIR = next(
    (
        candidate
        for candidate in PROJECT_CANDIDATES
        if os.path.exists(os.path.join(candidate, "llmprop_model.py"))
    ),
    None,
)

if PROJECT_DIR is None:
    raise FileNotFoundError(
        "Could not locate project root containing llmprop_model.py. "
        "Expected near the deployment folder."
    )

# Put the project root on sys.path so the llmprop_model import below resolves.
if PROJECT_DIR not in sys.path and os.path.isdir(PROJECT_DIR):
    sys.path.insert(0, PROJECT_DIR)

from llmprop_model import T5Predictor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Undo z-score scaling: return ``scaled_labels * labels_std + labels_mean``."""
    rescaled = scaled_labels * labels_std
    return rescaled + labels_mean
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# -------------------------
|
| 43 |
+
# CONFIG
|
| 44 |
+
# -------------------------
|
| 45 |
+
MODEL_PATH = os.path.join(
    PROJECT_DIR,
    "checkpoints",
    "samples",
    "regression",
    "best_checkpoint_for_band_gap.pt",
)
TOKENIZER_PATH = os.path.join(
    PROJECT_DIR,
    "tokenizers",
    "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge",
)
TRAIN_DATA_PATH = os.path.join(PROJECT_DIR, "data", "samples", "train_data.csv")
PROPERTY_NAME = "band_gap"
DEVICE = torch.device("cpu")

# Silence HF/Transformers startup logs for cleaner terminal output.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)


# -------------------------
# PATH CHECKS
# -------------------------
# Fail at import time with a precise message rather than inside a loader.
for _problem, _required_path in (
    ("Checkpoint not found", MODEL_PATH),
    ("Tokenizer path not found", TOKENIZER_PATH),
    ("Training data not found", TRAIN_DATA_PATH),
):
    if not os.path.exists(_required_path):
        raise FileNotFoundError(f"{_problem}: {_required_path}")


# -------------------------
# LOAD TRAIN LABEL STATS (z_norm)
# -------------------------
# The mean/std of the training labels are used to de-normalize model outputs.
train_df = pd.read_csv(TRAIN_DATA_PATH)
if PROPERTY_NAME not in train_df.columns:
    raise ValueError(f"Column '{PROPERTY_NAME}' not found in {TRAIN_DATA_PATH}")

train_labels = torch.tensor(
    train_df[PROPERTY_NAME].dropna().to_numpy(),
    dtype=torch.float32,
)
if train_labels.numel() == 0:
    raise ValueError(f"No non-null values found for '{PROPERTY_NAME}' in {TRAIN_DATA_PATH}")

TRAIN_LABEL_MEAN = torch.mean(train_labels)
TRAIN_LABEL_STD = torch.std(train_labels)
# torch.std yields NaN for a single sample (and for non-finite labels). NaN
# compares unequal to 0.0, so it must be rejected explicitly or every
# de-normalized prediction would silently become NaN.
if not bool(torch.isfinite(TRAIN_LABEL_STD)):
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is not finite "
        f"({train_labels.numel()} usable value(s) in {TRAIN_DATA_PATH}); "
        "z_norm de-normalization is undefined"
    )
if float(TRAIN_LABEL_STD) == 0.0:
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is 0.0; z_norm de-normalization is undefined"
    )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _quiet_call(fn, *args, **kwargs):
|
| 102 |
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
|
| 103 |
+
return fn(*args, **kwargs)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# -------------------------
|
| 107 |
+
# LOAD TOKENIZER
|
| 108 |
+
# -------------------------
|
| 109 |
+
tokenizer = _quiet_call(AutoTokenizer.from_pretrained, TOKENIZER_PATH)


# -------------------------
# LOAD MODEL
# -------------------------
# Passed to T5Predictor as its second positional argument
# (presumably the encoder hidden size -- confirm against llmprop_model).
base_model_output_size = 512
base_model = _quiet_call(T5EncoderModel.from_pretrained, "google/t5-v1_1-small")

# Match embedding matrix size to the tokenizer used during training.
base_model.resize_token_embeddings(len(tokenizer), mean_resizing=False)

model = T5Predictor(base_model, base_model_output_size, drop_rate=0.1, pooling="mean")


# -------------------------
# LOAD WEIGHTS
# -------------------------
# NOTE(review): depending on the installed torch version's weights_only default,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
state_dict = _quiet_call(torch.load, MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if checkpoint_vocab_size != model.model.shared.weight.shape[0]:
    model.model.resize_token_embeddings(checkpoint_vocab_size, mean_resizing=False)

# NOTE(review): strict=False does not raise on missing or unexpected keys, so a
# checkpoint that does not match this architecture would go unnoticed here.
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)
model.eval()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# -------------------------
|
| 145 |
+
# PREDICT FUNCTION
|
| 146 |
+
# -------------------------
|
| 147 |
+
def predict_band_gap(text, max_length=256):
    """Predict ``band_gap`` for a crystal description.

    Args:
        text: Description to tokenize. The result is reduced with ``.item()``,
            so this is expected to be a single description, not a batch.
        max_length: Token limit; longer inputs are truncated.

    Returns:
        The de-normalized prediction as a Python number.
    """
    encoded = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        # The model returns a pair; only its second element is used.
        _unused, normalized_output = model(token_ids, mask)
        denormalized = z_denormalize(
            normalized_output.squeeze(),
            TRAIN_LABEL_MEAN,
            TRAIN_LABEL_STD,
        )

    return denormalized.item()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# -------------------------
|
| 171 |
+
# TEST
|
| 172 |
+
# -------------------------
|
| 173 |
+
if __name__ == "__main__":
    # Command-line smoke test: predict band_gap for --text (a built-in sample by default).
    cli = argparse.ArgumentParser(description="Predict band_gap from text")
    cli.add_argument("--max_length", type=int, default=256, help="Tokenizer max length")
    cli.add_argument(
        "--text",
        type=str,
        default="A simple cubic crystalLiAl(MoO₄)₂ crystallizes in the triclinic P̅1 space group. Li¹⁺ is bonded in a 5-coordinate geometry to five O²⁻ atoms. There are a spread of Li–O bond distances ranging from 1.98–2.25 Å. There are two inequivalent Mo⁶⁺ sites. In the first Mo⁶⁺ site, Mo⁶⁺ is bonded in a 4-coordinate geometry to five O²⁻ atoms. There are a spread of Mo–O bond distances ranging from 1.74–2.46 Å. In the second Mo⁶⁺ site, Mo⁶⁺ is bonded to four O²⁻ atoms to form MoO₄ tetrahedra that share corners with three equivalent AlO₆ octahedra. The corner-sharing octahedral tilt angles range from 15–44°. There are a spread of Mo–O bond distances ranging from 1.77–1.82 Å. Al³⁺ is bonded to six O²⁻ atoms to form AlO₆ octahedra that share corners with three equivalent MoO₄ tetrahedra and an edgeedge with one AlO₆ octahedra. There are a spread of Al–O bond distances ranging from 1.88–1.95 Å. There are eight inequivalent O²⁻ sites. In the first O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Li¹⁺, one Mo⁶⁺, and one Al³⁺ atom. In the second O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Mo⁶⁺ and two equivalent Al³⁺ atoms. In the third O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fourth O²⁻ site, O²⁻ is bonded in a linear geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fifth O²⁻ site, O²⁻ is bonded in a linear geometry to one Mo⁶⁺ and one Al³⁺ atom. In the sixth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the seventh O²⁻ site, O²⁻ is bonded in a 4-coordinate geometry to one Li¹⁺, two equivalent Mo⁶⁺, and one Al³⁺ atom. In the eighth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Mo⁶⁺ and one Al³⁺ atom. with atoms arranged periodically and stable at room temperature.",
        help="Input text to predict band gap",
    )
    options = cli.parse_args()

    predicted = predict_band_gap(options.text, max_length=options.max_length)
    print(f"Predicted band_gap: {predicted:.6f}")
|
predict_e_above_hull.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import contextlib
|
| 3 |
+
import io
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
| 11 |
+
from transformers.utils import logging as transformers_logging
|
| 12 |
+
|
| 13 |
+
# Directory holding this script; the project root is searched for relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PARENT_DIR = os.path.dirname(SCRIPT_DIR)
PROJECT_CANDIDATES = [
    SCRIPT_DIR,
    _PARENT_DIR,
    os.path.join(_PARENT_DIR, "LLM-Prop"),
]

# The first candidate containing llmprop_model.py wins; None when none does.
PROJECT_DIR = next(
    (
        candidate
        for candidate in PROJECT_CANDIDATES
        if os.path.exists(os.path.join(candidate, "llmprop_model.py"))
    ),
    None,
)

if PROJECT_DIR is None:
    raise FileNotFoundError(
        "Could not locate project root containing llmprop_model.py. "
        "Expected near the deployment folder."
    )

# Put the project root on sys.path so the llmprop_model import below resolves.
if PROJECT_DIR not in sys.path and os.path.isdir(PROJECT_DIR):
    sys.path.insert(0, PROJECT_DIR)

from llmprop_model import T5Predictor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Undo z-score scaling: return ``scaled_labels * labels_std + labels_mean``."""
    rescaled = scaled_labels * labels_std
    return rescaled + labels_mean
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# -------------------------
|
| 43 |
+
# CONFIG
|
| 44 |
+
# -------------------------
|
| 45 |
+
MODEL_PATH = os.path.join(
    PROJECT_DIR,
    "checkpoints",
    "samples",
    "regression",
    "best_checkpoint_for_e_above_hull.pt",
)
TOKENIZER_PATH = os.path.join(
    PROJECT_DIR,
    "tokenizers",
    "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge",
)
TRAIN_DATA_PATH = os.path.join(PROJECT_DIR, "data", "samples", "train_data.csv")
PROPERTY_NAME = "e_above_hull"
DEVICE = torch.device("cpu")

# Silence HF/Transformers startup logs for cleaner terminal output.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)


# -------------------------
# PATH CHECKS
# -------------------------
# Fail at import time with a precise message rather than inside a loader.
for _problem, _required_path in (
    ("Checkpoint not found", MODEL_PATH),
    ("Tokenizer path not found", TOKENIZER_PATH),
    ("Training data not found", TRAIN_DATA_PATH),
):
    if not os.path.exists(_required_path):
        raise FileNotFoundError(f"{_problem}: {_required_path}")


# -------------------------
# LOAD TRAIN LABEL STATS (z_norm)
# -------------------------
# The mean/std of the training labels are used to de-normalize model outputs.
train_df = pd.read_csv(TRAIN_DATA_PATH)
if PROPERTY_NAME not in train_df.columns:
    raise ValueError(f"Column '{PROPERTY_NAME}' not found in {TRAIN_DATA_PATH}")

train_labels = torch.tensor(
    train_df[PROPERTY_NAME].dropna().to_numpy(),
    dtype=torch.float32,
)
if train_labels.numel() == 0:
    raise ValueError(f"No non-null values found for '{PROPERTY_NAME}' in {TRAIN_DATA_PATH}")

TRAIN_LABEL_MEAN = torch.mean(train_labels)
TRAIN_LABEL_STD = torch.std(train_labels)
# torch.std yields NaN for a single sample (and for non-finite labels). NaN
# compares unequal to 0.0, so it must be rejected explicitly or every
# de-normalized prediction would silently become NaN.
if not bool(torch.isfinite(TRAIN_LABEL_STD)):
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is not finite "
        f"({train_labels.numel()} usable value(s) in {TRAIN_DATA_PATH}); "
        "z_norm de-normalization is undefined"
    )
if float(TRAIN_LABEL_STD) == 0.0:
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is 0.0; z_norm de-normalization is undefined"
    )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _quiet_call(fn, *args, **kwargs):
|
| 102 |
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
|
| 103 |
+
return fn(*args, **kwargs)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# -------------------------
|
| 107 |
+
# LOAD TOKENIZER
|
| 108 |
+
# -------------------------
|
| 109 |
+
tokenizer = _quiet_call(AutoTokenizer.from_pretrained, TOKENIZER_PATH)


# -------------------------
# LOAD MODEL
# -------------------------
# Passed to T5Predictor as its second positional argument
# (presumably the encoder hidden size -- confirm against llmprop_model).
base_model_output_size = 512
base_model = _quiet_call(T5EncoderModel.from_pretrained, "google/t5-v1_1-small")

# Match embedding matrix size to the tokenizer used during training.
base_model.resize_token_embeddings(len(tokenizer), mean_resizing=False)

model = T5Predictor(base_model, base_model_output_size, drop_rate=0.1, pooling="mean")


# -------------------------
# LOAD WEIGHTS
# -------------------------
# NOTE(review): depending on the installed torch version's weights_only default,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
state_dict = _quiet_call(torch.load, MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if checkpoint_vocab_size != model.model.shared.weight.shape[0]:
    model.model.resize_token_embeddings(checkpoint_vocab_size, mean_resizing=False)

# NOTE(review): strict=False does not raise on missing or unexpected keys, so a
# checkpoint that does not match this architecture would go unnoticed here.
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)
model.eval()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# -------------------------
|
| 145 |
+
# PREDICT FUNCTION
|
| 146 |
+
# -------------------------
|
| 147 |
+
def predict_e_above_hull(text, max_length=256):
    """Predict ``e_above_hull`` for a crystal description.

    Args:
        text: Description to tokenize. The result is reduced with ``.item()``,
            so this is expected to be a single description, not a batch.
        max_length: Token limit; longer inputs are truncated.

    Returns:
        The de-normalized prediction as a Python number.
    """
    encoded = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        # The model returns a pair; only its second element is used.
        _unused, normalized_output = model(token_ids, mask)
        denormalized = z_denormalize(
            normalized_output.squeeze(),
            TRAIN_LABEL_MEAN,
            TRAIN_LABEL_STD,
        )

    return denormalized.item()
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# -------------------------
|
| 171 |
+
# TEST
|
| 172 |
+
# -------------------------
|
| 173 |
+
if __name__ == "__main__":
    # Command-line smoke test: predict e_above_hull for --text (a built-in sample by default).
    cli = argparse.ArgumentParser(description="Predict e_above_hull from text")
    cli.add_argument("--max_length", type=int, default=256, help="Tokenizer max length")
    cli.add_argument(
        "--text",
        type=str,
        default="A simple cubic crystalLiAl(MoO₄)₂ crystallizes in the triclinic P̅1 space group. Li¹⁺ is bonded in a 5-coordinate geometry to five O²⁻ atoms. There are a spread of Li-O bond distances ranging from 1.98-2.25 A. There are two inequivalent Mo⁶⁺ sites. In the first Mo⁶⁺ site, Mo⁶⁺ is bonded in a 4-coordinate geometry to five O²⁻ atoms. There are a spread of Mo-O bond distances ranging from 1.74-2.46 A. In the second Mo⁶⁺ site, Mo⁶⁺ is bonded to four O²⁻ atoms to form MoO₄ tetrahedra that share corners with three equivalent AlO₆ octahedra. The corner-sharing octahedral tilt angles range from 15-44 degrees. There are a spread of Mo-O bond distances ranging from 1.77-1.82 A. Al³⁺ is bonded to six O²⁻ atoms to form AlO₆ octahedra that share corners with three equivalent MoO₄ tetrahedra and an edgeedge with one AlO₆ octahedra. There are a spread of Al-O bond distances ranging from 1.88-1.95 A. There are eight inequivalent O²⁻ sites. In the first O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Li¹⁺, one Mo⁶⁺, and one Al³⁺ atom. In the second O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Mo⁶⁺ and two equivalent Al³⁺ atoms. In the third O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fourth O²⁻ site, O²⁻ is bonded in a linear geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fifth O²⁻ site, O²⁻ is bonded in a linear geometry to one Mo⁶⁺ and one Al³⁺ atom. In the sixth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the seventh O²⁻ site, O²⁻ is bonded in a 4-coordinate geometry to one Li¹⁺, two equivalent Mo⁶⁺, and one Al³⁺ atom. In the eighth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Mo⁶⁺ and one Al³⁺ atom. with atoms arranged periodically and stable at room temperature.",
        help="Input text to predict e_above_hull",
    )
    options = cli.parse_args()

    predicted = predict_e_above_hull(options.text, max_length=options.max_length)
    print(f"Predicted e_above_hull: {predicted:.6f}")
|
predict_epa.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import logging
|
| 5 |
+
import io
|
| 6 |
+
import contextlib
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
| 11 |
+
from transformers.utils import logging as transformers_logging
|
| 12 |
+
|
| 13 |
+
# Directory holding this script; the project root is searched for relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PARENT_DIR = os.path.dirname(SCRIPT_DIR)
PROJECT_CANDIDATES = [
    SCRIPT_DIR,
    _PARENT_DIR,
    os.path.join(_PARENT_DIR, "LLM-Prop"),
]

# The first candidate containing llmprop_model.py wins; None when none does.
PROJECT_DIR = next(
    (
        candidate
        for candidate in PROJECT_CANDIDATES
        if os.path.exists(os.path.join(candidate, "llmprop_model.py"))
    ),
    None,
)

if PROJECT_DIR is None:
    raise FileNotFoundError(
        "Could not locate project root containing llmprop_model.py. "
        "Expected near the deployment folder."
    )

# Put the project root on sys.path so the llmprop_model import below resolves.
if PROJECT_DIR not in sys.path and os.path.isdir(PROJECT_DIR):
    sys.path.insert(0, PROJECT_DIR)

from llmprop_model import T5Predictor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Undo z-score scaling: return ``scaled_labels * labels_std + labels_mean``."""
    rescaled = scaled_labels * labels_std
    return rescaled + labels_mean
|
| 40 |
+
|
| 41 |
+
# -------------------------
|
| 42 |
+
# CONFIG
|
| 43 |
+
# -------------------------
|
| 44 |
+
MODEL_PATH = os.path.join(
    PROJECT_DIR,
    "checkpoints",
    "samples",
    "regression",
    "best_checkpoint_for_energy_per_atom.pt",
)
TOKENIZER_PATH = os.path.join(
    PROJECT_DIR,
    "tokenizers",
    "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge",
)
TRAIN_DATA_PATH = os.path.join(PROJECT_DIR, "data", "samples", "train_data.csv")
PROPERTY_NAME = "energy_per_atom"
DEVICE = torch.device("cpu")

# Silence HF/Transformers startup logs for cleaner terminal output.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)


# -------------------------
# PATH CHECKS
# -------------------------
# Fail at import time with a precise message rather than inside a loader.
for _problem, _required_path in (
    ("Checkpoint not found", MODEL_PATH),
    ("Tokenizer path not found", TOKENIZER_PATH),
    ("Training data not found", TRAIN_DATA_PATH),
):
    if not os.path.exists(_required_path):
        raise FileNotFoundError(f"{_problem}: {_required_path}")


# -------------------------
# LOAD TRAIN LABEL STATS (z_norm)
# -------------------------
# The mean/std of the training labels are used to de-normalize model outputs.
train_df = pd.read_csv(TRAIN_DATA_PATH)
if PROPERTY_NAME not in train_df.columns:
    raise ValueError(f"Column '{PROPERTY_NAME}' not found in {TRAIN_DATA_PATH}")

train_labels = torch.tensor(
    train_df[PROPERTY_NAME].dropna().to_numpy(),
    dtype=torch.float32,
)
if train_labels.numel() == 0:
    raise ValueError(f"No non-null values found for '{PROPERTY_NAME}' in {TRAIN_DATA_PATH}")

TRAIN_LABEL_MEAN = torch.mean(train_labels)
TRAIN_LABEL_STD = torch.std(train_labels)
# torch.std yields NaN for a single sample (and for non-finite labels). NaN
# compares unequal to 0.0, so it must be rejected explicitly or every
# de-normalized prediction would silently become NaN.
if not bool(torch.isfinite(TRAIN_LABEL_STD)):
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is not finite "
        f"({train_labels.numel()} usable value(s) in {TRAIN_DATA_PATH}); "
        "z_norm de-normalization is undefined"
    )
if float(TRAIN_LABEL_STD) == 0.0:
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is 0.0; z_norm de-normalization is undefined"
    )
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _quiet_call(fn, *args, **kwargs):
|
| 101 |
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
|
| 102 |
+
return fn(*args, **kwargs)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# -------------------------
|
| 106 |
+
# LOAD TOKENIZER
|
| 107 |
+
# -------------------------
|
| 108 |
+
tokenizer = _quiet_call(AutoTokenizer.from_pretrained, TOKENIZER_PATH)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# -------------------------
|
| 112 |
+
# LOAD MODEL
|
| 113 |
+
# -------------------------
|
| 114 |
+
base_model = _quiet_call(T5EncoderModel.from_pretrained, "google/t5-v1_1-small")
base_model_output_size = 512


def _resize_embeddings_compat(t5_model, vocab_size):
    """Resize ``t5_model``'s token-embedding matrix to ``vocab_size`` rows.

    ``mean_resizing=False`` is requested where the installed transformers
    release accepts it. Releases that predate that keyword (requirements.txt
    pins transformers==4.23.1) raise TypeError, so the call is retried
    without it.
    """
    try:
        return t5_model.resize_token_embeddings(vocab_size, mean_resizing=False)
    except TypeError:
        return t5_model.resize_token_embeddings(vocab_size)


# Match embedding matrix size to the tokenizer used during training.
_resize_embeddings_compat(base_model, len(tokenizer))

model = T5Predictor(
    base_model,
    base_model_output_size,
    drop_rate=0.1,
    pooling="mean",
)


# -------------------------
# LOAD WEIGHTS
# -------------------------
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = _quiet_call(torch.load, MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
    _resize_embeddings_compat(model.model, checkpoint_vocab_size)

# strict=False: checkpoint/model key mismatches are tolerated instead of raising.
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)
model.eval()
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# -------------------------
|
| 144 |
+
# PREDICT FUNCTION
|
| 145 |
+
# -------------------------
|
| 146 |
+
def predict_epa(text, max_length=256):
    """Predict energy_per_atom for a text description.

    The text is tokenized with truncation at ``max_length``, fed to the
    module-level ``model`` with gradients disabled, and the second element of
    the model output is de-normalized with the training-label mean and
    standard deviation.

    Returns:
        The de-normalized prediction as a Python float.
    """
    encoded = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        _, normalized = model(ids, mask)
        denormalized = z_denormalize(
            normalized.squeeze(),
            TRAIN_LABEL_MEAN,
            TRAIN_LABEL_STD,
        )
        # .item() only works on a single-element tensor.
        epa_value = denormalized.item()

    return epa_value
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# -------------------------
|
| 170 |
+
# TEST
|
| 171 |
+
# -------------------------
|
| 172 |
+
if __name__ == "__main__":
|
| 173 |
+
parser = argparse.ArgumentParser(description="Predict energy_per_atom from text")
|
| 174 |
+
parser.add_argument("--max_length", type=int, default=256, help="Tokenizer max length")
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"--text",
|
| 177 |
+
type=str,
|
| 178 |
+
default="A simple cubic crystalLiAl(MoO₄)₂ crystallizes in the triclinic P̅1 space group. Li¹⁺ is bonded in a 5-coordinate geometry to five O²⁻ atoms. There are a spread of Li–O bond distances ranging from 1.98–2.25 Å. There are two inequivalent Mo⁶⁺ sites. In the first Mo⁶⁺ site, Mo⁶⁺ is bonded in a 4-coordinate geometry to five O²⁻ atoms. There are a spread of Mo–O bond distances ranging from 1.74–2.46 Å. In the second Mo⁶⁺ site, Mo⁶⁺ is bonded to four O²⁻ atoms to form MoO₄ tetrahedra that share corners with three equivalent AlO₆ octahedra. The corner-sharing octahedral tilt angles range from 15–44°. There are a spread of Mo–O bond distances ranging from 1.77–1.82 Å. Al³⁺ is bonded to six O²⁻ atoms to form AlO₆ octahedra that share corners with three equivalent MoO₄ tetrahedra and an edgeedge with one AlO₆ octahedra. There are a spread of Al–O bond distances ranging from 1.88–1.95 Å. There are eight inequivalent O²⁻ sites. In the first O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Li¹⁺, one Mo⁶⁺, and one Al³⁺ atom. In the second O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Mo⁶⁺ and two equivalent Al³⁺ atoms. In the third O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fourth O²⁻ site, O²⁻ is bonded in a linear geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fifth O²⁻ site, O²⁻ is bonded in a linear geometry to one Mo⁶⁺ and one Al³⁺ atom. In the sixth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the seventh O²⁻ site, O²⁻ is bonded in a 4-coordinate geometry to one Li¹⁺, two equivalent Mo⁶⁺, and one Al³⁺ atom. In the eighth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Mo⁶⁺ and one Al³⁺ atom. with atoms arranged periodically and stable at room temperature.",
|
| 179 |
+
help="Input text to predict EPA",
|
| 180 |
+
)
|
| 181 |
+
args = parser.parse_args()
|
| 182 |
+
|
| 183 |
+
value = predict_epa(args.text, max_length=args.max_length)
|
| 184 |
+
print(f"Predicted energy_per_atom: {value:.6f}")
|
predict_fepa.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import contextlib
|
| 3 |
+
import io
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
| 11 |
+
from transformers.utils import logging as transformers_logging
|
| 12 |
+
|
| 13 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 14 |
+
PROJECT_CANDIDATES = [
|
| 15 |
+
SCRIPT_DIR,
|
| 16 |
+
os.path.dirname(SCRIPT_DIR),
|
| 17 |
+
os.path.join(os.path.dirname(SCRIPT_DIR), "LLM-Prop"),
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
PROJECT_DIR = None
|
| 21 |
+
for candidate in PROJECT_CANDIDATES:
|
| 22 |
+
if os.path.exists(os.path.join(candidate, "llmprop_model.py")):
|
| 23 |
+
PROJECT_DIR = candidate
|
| 24 |
+
break
|
| 25 |
+
|
| 26 |
+
if PROJECT_DIR is None:
|
| 27 |
+
raise FileNotFoundError(
|
| 28 |
+
"Could not locate project root containing llmprop_model.py. "
|
| 29 |
+
"Expected near the deployment folder."
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
if os.path.isdir(PROJECT_DIR) and PROJECT_DIR not in sys.path:
|
| 33 |
+
sys.path.insert(0, PROJECT_DIR)
|
| 34 |
+
|
| 35 |
+
from llmprop_model import T5Predictor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Invert z-score scaling: multiply by ``labels_std``, then add ``labels_mean``."""
    rescaled = scaled_labels * labels_std
    return rescaled + labels_mean
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# -------------------------
|
| 43 |
+
# CONFIG
|
| 44 |
+
# -------------------------
|
| 45 |
+
MODEL_PATH = os.path.join(
|
| 46 |
+
PROJECT_DIR,
|
| 47 |
+
"checkpoints",
|
| 48 |
+
"samples",
|
| 49 |
+
"regression",
|
| 50 |
+
"best_checkpoint_for_fepa.pt",
|
| 51 |
+
)
|
| 52 |
+
TOKENIZER_PATH = os.path.join(
|
| 53 |
+
PROJECT_DIR,
|
| 54 |
+
"tokenizers",
|
| 55 |
+
"t5_tokenizer_trained_on_modified_part_of_C4_and_textedge",
|
| 56 |
+
)
|
| 57 |
+
TRAIN_DATA_PATH = os.path.join(PROJECT_DIR, "data", "samples", "train_data.csv")
|
| 58 |
+
PROPERTY_NAME = "formation_energy_per_atom"
|
| 59 |
+
DEVICE = torch.device("cpu")
|
| 60 |
+
|
| 61 |
+
# Silence HF/Transformers startup logs for cleaner terminal output.
|
| 62 |
+
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
|
| 63 |
+
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
| 64 |
+
transformers_logging.set_verbosity_error()
|
| 65 |
+
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# -------------------------
|
| 69 |
+
# PATH CHECKS
|
| 70 |
+
# -------------------------
|
| 71 |
+
if not os.path.exists(MODEL_PATH):
|
| 72 |
+
raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH}")
|
| 73 |
+
if not os.path.exists(TOKENIZER_PATH):
|
| 74 |
+
raise FileNotFoundError(f"Tokenizer path not found: {TOKENIZER_PATH}")
|
| 75 |
+
if not os.path.exists(TRAIN_DATA_PATH):
|
| 76 |
+
raise FileNotFoundError(f"Training data not found: {TRAIN_DATA_PATH}")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# -------------------------
|
| 80 |
+
# LOAD TRAIN LABEL STATS (z_norm)
|
| 81 |
+
# -------------------------
|
| 82 |
+
# Recompute the mean/std of the training labels; predict_fepa passes them to
# z_denormalize to map the model output back to the original label scale.
train_df = pd.read_csv(TRAIN_DATA_PATH)
if PROPERTY_NAME not in train_df.columns:
    raise ValueError(f"Column '{PROPERTY_NAME}' not found in {TRAIN_DATA_PATH}")

train_labels = torch.tensor(
    train_df[PROPERTY_NAME].dropna().to_numpy(),
    dtype=torch.float32,
)
if train_labels.numel() == 0:
    raise ValueError(f"No non-null values found for '{PROPERTY_NAME}' in {TRAIN_DATA_PATH}")

TRAIN_LABEL_MEAN = torch.mean(train_labels)
TRAIN_LABEL_STD = torch.std(train_labels)
# torch.std applies Bessel's correction by default, so a single label gives
# NaN; NaN == 0.0 is False, so non-finite values must be rejected explicitly.
if not bool(torch.isfinite(TRAIN_LABEL_STD)):
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is not finite "
        f"({train_labels.numel()} label(s) in {TRAIN_DATA_PATH}); z_norm de-normalization is undefined"
    )
if float(TRAIN_LABEL_STD) == 0.0:
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is 0.0; z_norm de-normalization is undefined"
    )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _quiet_call(fn, *args, **kwargs):
|
| 102 |
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
|
| 103 |
+
return fn(*args, **kwargs)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# -------------------------
|
| 107 |
+
# LOAD TOKENIZER
|
| 108 |
+
# -------------------------
|
| 109 |
+
tokenizer = _quiet_call(AutoTokenizer.from_pretrained, TOKENIZER_PATH)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# -------------------------
|
| 113 |
+
# LOAD MODEL
|
| 114 |
+
# -------------------------
|
| 115 |
+
base_model = _quiet_call(T5EncoderModel.from_pretrained, "google/t5-v1_1-small")
base_model_output_size = 512


def _resize_embeddings_compat(t5_model, vocab_size):
    """Resize ``t5_model``'s token-embedding matrix to ``vocab_size`` rows.

    ``mean_resizing=False`` is requested where the installed transformers
    release accepts it. Releases that predate that keyword (requirements.txt
    pins transformers==4.23.1) raise TypeError, so the call is retried
    without it.
    """
    try:
        return t5_model.resize_token_embeddings(vocab_size, mean_resizing=False)
    except TypeError:
        return t5_model.resize_token_embeddings(vocab_size)


# Match embedding matrix size to the tokenizer used during training.
_resize_embeddings_compat(base_model, len(tokenizer))

model = T5Predictor(
    base_model,
    base_model_output_size,
    drop_rate=0.1,
    pooling="mean",
)


# -------------------------
# LOAD WEIGHTS
# -------------------------
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = _quiet_call(torch.load, MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
    _resize_embeddings_compat(model.model, checkpoint_vocab_size)

# strict=False: checkpoint/model key mismatches are tolerated instead of raising.
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)
model.eval()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# -------------------------
|
| 145 |
+
# PREDICT FUNCTION
|
| 146 |
+
# -------------------------
|
| 147 |
+
def predict_fepa(text, max_length=256):
    """Predict formation_energy_per_atom for a text description.

    The text is tokenized with truncation at ``max_length``, fed to the
    module-level ``model`` with gradients disabled, and the second element of
    the model output is de-normalized with the training-label mean and
    standard deviation.

    Returns:
        The de-normalized prediction as a Python float.
    """
    encoded = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        _, normalized = model(ids, mask)
        denormalized = z_denormalize(
            normalized.squeeze(),
            TRAIN_LABEL_MEAN,
            TRAIN_LABEL_STD,
        )
        # .item() only works on a single-element tensor.
        fepa_value = denormalized.item()

    return fepa_value
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# -------------------------
|
| 171 |
+
# TEST
|
| 172 |
+
# -------------------------
|
| 173 |
+
if __name__ == "__main__":
|
| 174 |
+
parser = argparse.ArgumentParser(description="Predict formation_energy_per_atom from text")
|
| 175 |
+
parser.add_argument("--max_length", type=int, default=256, help="Tokenizer max length")
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--text",
|
| 178 |
+
type=str,
|
| 179 |
+
default="A simple cubic crystalLiAl(MoO₄)₂ crystallizes in the triclinic P̅1 space group. Li¹⁺ is bonded in a 5-coordinate geometry to five O²⁻ atoms. There are a spread of Li–O bond distances ranging from 1.98–2.25 Å. There are two inequivalent Mo⁶⁺ sites. In the first Mo⁶⁺ site, Mo⁶⁺ is bonded in a 4-coordinate geometry to five O²⁻ atoms. There are a spread of Mo–O bond distances ranging from 1.74–2.46 Å. In the second Mo⁶⁺ site, Mo⁶⁺ is bonded to four O²⁻ atoms to form MoO₄ tetrahedra that share corners with three equivalent AlO₆ octahedra. The corner-sharing octahedral tilt angles range from 15–44°. There are a spread of Mo–O bond distances ranging from 1.77–1.82 Å. Al³⁺ is bonded to six O²⁻ atoms to form AlO₆ octahedra that share corners with three equivalent MoO₄ tetrahedra and an edgeedge with one AlO₆ octahedra. There are a spread of Al–O bond distances ranging from 1.88–1.95 Å. There are eight inequivalent O²⁻ sites. In the first O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Li¹⁺, one Mo⁶⁺, and one Al³⁺ atom. In the second O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Mo⁶⁺ and two equivalent Al³⁺ atoms. In the third O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fourth O²⁻ site, O²⁻ is bonded in a linear geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fifth O²⁻ site, O²⁻ is bonded in a linear geometry to one Mo⁶⁺ and one Al³⁺ atom. In the sixth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the seventh O²⁻ site, O²⁻ is bonded in a 4-coordinate geometry to one Li¹⁺, two equivalent Mo⁶⁺, and one Al³⁺ atom. In the eighth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Mo⁶⁺ and one Al³⁺ atom. with atoms arranged periodically and stable at room temperature.",
|
| 180 |
+
help="Input text to predict FEPA",
|
| 181 |
+
)
|
| 182 |
+
args = parser.parse_args()
|
| 183 |
+
|
| 184 |
+
value = predict_fepa(args.text, max_length=args.max_length)
|
| 185 |
+
print(f"Predicted formation_energy_per_atom: {value:.6f}")
|
predict_is_gap_direct.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import argparse
|
| 5 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
| 6 |
+
|
| 7 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 8 |
+
PROJECT_CANDIDATES = [
|
| 9 |
+
SCRIPT_DIR,
|
| 10 |
+
os.path.dirname(SCRIPT_DIR),
|
| 11 |
+
os.path.join(os.path.dirname(SCRIPT_DIR), "LLM-Prop"),
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
PROJECT_DIR = None
|
| 15 |
+
for candidate in PROJECT_CANDIDATES:
|
| 16 |
+
if os.path.exists(os.path.join(candidate, "llmprop_model.py")):
|
| 17 |
+
PROJECT_DIR = candidate
|
| 18 |
+
break
|
| 19 |
+
|
| 20 |
+
if PROJECT_DIR is None:
|
| 21 |
+
raise FileNotFoundError(
|
| 22 |
+
"Could not locate project root containing llmprop_model.py. "
|
| 23 |
+
"Expected near the deployment folder."
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
if os.path.isdir(PROJECT_DIR) and PROJECT_DIR not in sys.path:
|
| 27 |
+
sys.path.insert(0, PROJECT_DIR)
|
| 28 |
+
|
| 29 |
+
from llmprop_model import T5Predictor
|
| 30 |
+
|
| 31 |
+
# -------------------------
|
| 32 |
+
# CONFIG
|
| 33 |
+
# -------------------------
|
| 34 |
+
MODEL_PATH = os.path.join(PROJECT_DIR, "checkpoints", "samples", "classification", "best_checkpoint_for_is_gap_direct.pt")
|
| 35 |
+
|
| 36 |
+
TOKENIZER_PATH = os.path.join(PROJECT_DIR, "tokenizers", "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge")
|
| 37 |
+
|
| 38 |
+
DEVICE = torch.device("cpu")
|
| 39 |
+
|
| 40 |
+
# -------------------------
|
| 41 |
+
# LOAD TOKENIZER
|
| 42 |
+
# -------------------------
|
| 43 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
|
| 44 |
+
|
| 45 |
+
# -------------------------
|
| 46 |
+
# LOAD MODEL
|
| 47 |
+
# -------------------------
|
| 48 |
+
base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
base_model_output_size = 512


def _resize_embeddings_compat(t5_model, vocab_size):
    """Resize ``t5_model``'s token-embedding matrix to ``vocab_size`` rows.

    ``mean_resizing=False`` is requested where the installed transformers
    release accepts it. Releases that predate that keyword (requirements.txt
    pins transformers==4.23.1) raise TypeError, so the call is retried
    without it.
    """
    try:
        return t5_model.resize_token_embeddings(vocab_size, mean_resizing=False)
    except TypeError:
        return t5_model.resize_token_embeddings(vocab_size)


# Match embedding matrix size to the tokenizer used during training.
_resize_embeddings_compat(base_model, len(tokenizer))

model = T5Predictor(
    base_model,
    base_model_output_size,
    drop_rate=0.1,
    pooling="mean",  # pooling mode confirmed from the training command
)

# -------------------------
# LOAD WEIGHTS
# -------------------------
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
    _resize_embeddings_compat(model.model, checkpoint_vocab_size)

# strict=False: checkpoint/model key mismatches are tolerated instead of raising.
model.load_state_dict(state_dict, strict=False)

model.to(DEVICE)
model.eval()
|
| 75 |
+
|
| 76 |
+
# -------------------------
|
| 77 |
+
# PREDICT FUNCTION
|
| 78 |
+
# -------------------------
|
| 79 |
+
def predict(text, threshold=0.33, max_length=256):
    """Predict the is_gap_direct label for a text description.

    Args:
        text: Input description; it is tokenized as-is, with no preprocessing.
        threshold: Probability above which the result is "TRUE".
        max_length: Tokenizer truncation length. Defaults to 256, the value
            that was previously hard-coded here.

    Returns:
        A ``(label, prob)`` tuple where ``label`` is "TRUE" or "FALSE" and
        ``prob`` is the sigmoid of the model output as a Python float.
    """
    # No preprocessing of the text (important).
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )

    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs["attention_mask"].to(DEVICE)

    with torch.no_grad():
        _, predictions = model(input_ids, attention_mask)

    # .item() only works on a single-element tensor.
    prob = torch.sigmoid(predictions).item()

    label = "TRUE" if prob > threshold else "FALSE"
    return label, prob
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# -------------------------
|
| 106 |
+
# TEST
|
| 107 |
+
# -------------------------
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
parser = argparse.ArgumentParser(description="Predict is_gap_direct from text")
|
| 110 |
+
parser.add_argument("--threshold", type=float, default=0.33, help="Decision threshold for TRUE/FALSE")
|
| 111 |
+
parser.add_argument("--text", type=str, default="Rb₂NaPrCl₆ is (Cubic) Perovskite-derived structured and crystallizes in the cubic Fm̅3m space group. Rb¹⁺ is bonded to twelve equivalent Cl¹⁻ atoms to form RbCl₁₂ cuboctahedra that share corners with twelve equivalent RbCl₁₂ cuboctahedra, faces with six equivalent RbCl₁₂ cuboctahedra, faces with four equivalent NaCl₆ octahedra, and faces with four equivalent PrCl₆ octahedra. All Rb–Cl bond lengths are 3.90 Å. Na¹⁺ is bonded to six equivalent Cl¹⁻ atoms to form NaCl₆ octahedra that share corners with six equivalent PrCl₆ octahedra and faces with eight equivalent RbCl₁₂ cuboctahedra. The corner-sharing octahedra are not tilted. All Na–Cl bond lengths are 2.76 Å. Pr³⁺ is bonded to six equivalent Cl¹⁻ atoms to form PrCl₆ octahedra that share corners with six equivalent NaCl₆ octahedra and faces with eight equivalent RbCl₁₂ cuboctahedra. The corner-sharing octahedra are not tilted. All Pr–Cl bond lengths are 2.75 Å. Cl¹⁻ is bonded in a distorted linear geometry to four equivalent Rb¹⁺, one Na¹⁺, and one Pr³⁺ atom.", help="Input text to classify")
|
| 112 |
+
args = parser.parse_args()
|
| 113 |
+
result, prob = predict(args.text, threshold=args.threshold)
|
| 114 |
+
print(result, prob)
|
predict_volume.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import contextlib
|
| 3 |
+
import io
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
| 11 |
+
from transformers.utils import logging as transformers_logging
|
| 12 |
+
|
| 13 |
+
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 14 |
+
PROJECT_CANDIDATES = [
|
| 15 |
+
SCRIPT_DIR,
|
| 16 |
+
os.path.dirname(SCRIPT_DIR),
|
| 17 |
+
os.path.join(os.path.dirname(SCRIPT_DIR), "LLM-Prop"),
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
PROJECT_DIR = None
|
| 21 |
+
for candidate in PROJECT_CANDIDATES:
|
| 22 |
+
if os.path.exists(os.path.join(candidate, "llmprop_model.py")):
|
| 23 |
+
PROJECT_DIR = candidate
|
| 24 |
+
break
|
| 25 |
+
|
| 26 |
+
if PROJECT_DIR is None:
|
| 27 |
+
raise FileNotFoundError(
|
| 28 |
+
"Could not locate project root containing llmprop_model.py. "
|
| 29 |
+
"Expected near the deployment folder."
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
if os.path.isdir(PROJECT_DIR) and PROJECT_DIR not in sys.path:
|
| 33 |
+
sys.path.insert(0, PROJECT_DIR)
|
| 34 |
+
|
| 35 |
+
from llmprop_model import T5Predictor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Map z-scored values back to the original scale (``x * std + mean``)."""
    spread = scaled_labels * labels_std
    return spread + labels_mean
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# -------------------------
|
| 43 |
+
# CONFIG
|
| 44 |
+
# -------------------------
|
| 45 |
+
MODEL_PATH = os.path.join(
|
| 46 |
+
PROJECT_DIR,
|
| 47 |
+
"checkpoints",
|
| 48 |
+
"samples",
|
| 49 |
+
"regression",
|
| 50 |
+
"best_checkpoint_for_volume.pt",
|
| 51 |
+
)
|
| 52 |
+
TOKENIZER_PATH = os.path.join(
|
| 53 |
+
PROJECT_DIR,
|
| 54 |
+
"tokenizers",
|
| 55 |
+
"t5_tokenizer_trained_on_modified_part_of_C4_and_textedge",
|
| 56 |
+
)
|
| 57 |
+
TRAIN_DATA_PATH = os.path.join(PROJECT_DIR, "data", "samples", "train_data.csv")
|
| 58 |
+
PROPERTY_NAME = "volume"
|
| 59 |
+
DEVICE = torch.device("cpu")
|
| 60 |
+
|
| 61 |
+
# Silence HF/Transformers startup logs for cleaner terminal output.
|
| 62 |
+
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
|
| 63 |
+
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
|
| 64 |
+
transformers_logging.set_verbosity_error()
|
| 65 |
+
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# -------------------------
|
| 69 |
+
# PATH CHECKS
|
| 70 |
+
# -------------------------
|
| 71 |
+
if not os.path.exists(MODEL_PATH):
|
| 72 |
+
raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH}")
|
| 73 |
+
if not os.path.exists(TOKENIZER_PATH):
|
| 74 |
+
raise FileNotFoundError(f"Tokenizer path not found: {TOKENIZER_PATH}")
|
| 75 |
+
if not os.path.exists(TRAIN_DATA_PATH):
|
| 76 |
+
raise FileNotFoundError(f"Training data not found: {TRAIN_DATA_PATH}")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# -------------------------
|
| 80 |
+
# LOAD TRAIN LABEL STATS (z_norm)
|
| 81 |
+
# -------------------------
|
| 82 |
+
# Recompute the mean/std of the training labels; predict_volume passes them to
# z_denormalize to map the model output back to the original label scale.
train_df = pd.read_csv(TRAIN_DATA_PATH)
if PROPERTY_NAME not in train_df.columns:
    raise ValueError(f"Column '{PROPERTY_NAME}' not found in {TRAIN_DATA_PATH}")

train_labels = torch.tensor(
    train_df[PROPERTY_NAME].dropna().to_numpy(),
    dtype=torch.float32,
)
if train_labels.numel() == 0:
    raise ValueError(f"No non-null values found for '{PROPERTY_NAME}' in {TRAIN_DATA_PATH}")

TRAIN_LABEL_MEAN = torch.mean(train_labels)
TRAIN_LABEL_STD = torch.std(train_labels)
# torch.std applies Bessel's correction by default, so a single label gives
# NaN; NaN == 0.0 is False, so non-finite values must be rejected explicitly.
if not bool(torch.isfinite(TRAIN_LABEL_STD)):
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is not finite "
        f"({train_labels.numel()} label(s) in {TRAIN_DATA_PATH}); z_norm de-normalization is undefined"
    )
if float(TRAIN_LABEL_STD) == 0.0:
    raise ValueError(
        f"Standard deviation for '{PROPERTY_NAME}' is 0.0; z_norm de-normalization is undefined"
    )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _quiet_call(fn, *args, **kwargs):
|
| 102 |
+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
|
| 103 |
+
return fn(*args, **kwargs)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# -------------------------
|
| 107 |
+
# LOAD TOKENIZER
|
| 108 |
+
# -------------------------
|
| 109 |
+
tokenizer = _quiet_call(AutoTokenizer.from_pretrained, TOKENIZER_PATH)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# -------------------------
|
| 113 |
+
# LOAD MODEL
|
| 114 |
+
# -------------------------
|
| 115 |
+
base_model = _quiet_call(T5EncoderModel.from_pretrained, "google/t5-v1_1-small")
base_model_output_size = 512


def _resize_embeddings_compat(t5_model, vocab_size):
    """Resize ``t5_model``'s token-embedding matrix to ``vocab_size`` rows.

    ``mean_resizing=False`` is requested where the installed transformers
    release accepts it. Releases that predate that keyword (requirements.txt
    pins transformers==4.23.1) raise TypeError, so the call is retried
    without it.
    """
    try:
        return t5_model.resize_token_embeddings(vocab_size, mean_resizing=False)
    except TypeError:
        return t5_model.resize_token_embeddings(vocab_size)


# Match embedding matrix size to the tokenizer used during training.
_resize_embeddings_compat(base_model, len(tokenizer))

model = T5Predictor(
    base_model,
    base_model_output_size,
    drop_rate=0.1,
    pooling="mean",
)


# -------------------------
# LOAD WEIGHTS
# -------------------------
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = _quiet_call(torch.load, MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
    _resize_embeddings_compat(model.model, checkpoint_vocab_size)

# strict=False: checkpoint/model key mismatches are tolerated instead of raising.
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)
model.eval()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# -------------------------
|
| 145 |
+
# PREDICT FUNCTION
|
| 146 |
+
# -------------------------
|
| 147 |
+
def predict_volume(text, max_length=256):
    """Predict volume for a text description.

    The text is tokenized with truncation at ``max_length``, fed to the
    module-level ``model`` with gradients disabled, and the second element of
    the model output is de-normalized with the training-label mean and
    standard deviation.

    Returns:
        The de-normalized prediction as a Python float.
    """
    encoded = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        _, normalized = model(ids, mask)
        denormalized = z_denormalize(
            normalized.squeeze(),
            TRAIN_LABEL_MEAN,
            TRAIN_LABEL_STD,
        )
        # .item() only works on a single-element tensor.
        volume_value = denormalized.item()

    return volume_value
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# -------------------------
|
| 171 |
+
# TEST
|
| 172 |
+
# -------------------------
|
| 173 |
+
if __name__ == "__main__":
|
| 174 |
+
parser = argparse.ArgumentParser(description="Predict volume from text")
|
| 175 |
+
parser.add_argument("--max_length", type=int, default=256, help="Tokenizer max length")
|
| 176 |
+
parser.add_argument(
|
| 177 |
+
"--text",
|
| 178 |
+
type=str,
|
| 179 |
+
default="A simple cubic crystalLiAl(MoO₄)₂ crystallizes in the triclinic P̅1 space group. Li¹⁺ is bonded in a 5-coordinate geometry to five O²⁻ atoms. There are a spread of Li-O bond distances ranging from 1.98-2.25 A. There are two inequivalent Mo⁶⁺ sites. In the first Mo⁶⁺ site, Mo⁶⁺ is bonded in a 4-coordinate geometry to five O²⁻ atoms. There are a spread of Mo-O bond distances ranging from 1.74-2.46 A. In the second Mo⁶⁺ site, Mo⁶⁺ is bonded to four O²⁻ atoms to form MoO₄ tetrahedra that share corners with three equivalent AlO₆ octahedra. The corner-sharing octahedral tilt angles range from 15-44 degrees. There are a spread of Mo-O bond distances ranging from 1.77-1.82 A. Al³⁺ is bonded to six O²⁻ atoms to form AlO₆ octahedra that share corners with three equivalent MoO₄ tetrahedra and an edgeedge with one AlO₆ octahedra. There are a spread of Al-O bond distances ranging from 1.88-1.95 A. There are eight inequivalent O²⁻ sites. In the first O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Li¹⁺, one Mo⁶⁺, and one Al³⁺ atom. In the second O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Mo⁶⁺ and two equivalent Al³⁺ atoms. In the third O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fourth O²⁻ site, O²⁻ is bonded in a linear geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fifth O²⁻ site, O²⁻ is bonded in a linear geometry to one Mo⁶⁺ and one Al³⁺ atom. In the sixth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the seventh O²⁻ site, O²⁻ is bonded in a 4-coordinate geometry to one Li¹⁺, two equivalent Mo⁶⁺, and one Al³⁺ atom. In the eighth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Mo⁶⁺ and one Al³⁺ atom. with atoms arranged periodically and stable at room temperature.",
|
| 180 |
+
help="Input text to predict volume",
|
| 181 |
+
)
|
| 182 |
+
args = parser.parse_args()
|
| 183 |
+
|
| 184 |
+
value = predict_volume(args.text, max_length=args.max_length)
|
| 185 |
+
print(f"Predicted volume: {value:.6f}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
huggingface_hub>=0.23.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
pydantic>=2.7.0
|
| 5 |
+
uvicorn>=0.29.0
|
| 6 |
+
torch==2.1.0
|
| 7 |
+
pandas==2.0.1
|
| 8 |
+
transformers==4.23.1
|
| 9 |
+
sentencepiece==0.1.97
|
| 10 |
+
tokenizers==0.13.1
|
| 11 |
+
torchmetrics>=1.4.0
|
| 12 |
+
scikit-learn==1.2.2
|
| 13 |
+
tqdm==4.66.1
|