# Page header captured together with the file (not program code):
# Spaces: Running | File size: 6,588 Bytes | revisions: e620469 28f525e 12f65ef
import argparse
import contextlib
import io
import logging
import os
import sys
import pandas as pd
import torch
from transformers import AutoTokenizer, T5EncoderModel
from transformers.utils import logging as transformers_logging
# Locate the directory that contains llmprop_model.py and make it importable.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PARENT_DIR = os.path.dirname(SCRIPT_DIR)

# Searched in order: the script's own folder, its parent, and a sibling
# "LLM-Prop" folder under the parent.
PROJECT_CANDIDATES = [
    SCRIPT_DIR,
    _PARENT_DIR,
    os.path.join(_PARENT_DIR, "LLM-Prop"),
]

PROJECT_DIR = None
for candidate in PROJECT_CANDIDATES:
    if os.path.exists(os.path.join(candidate, "llmprop_model.py")):
        # First candidate holding the model module wins.
        PROJECT_DIR = candidate
        break
else:
    # No candidate contained llmprop_model.py.
    raise FileNotFoundError(
        "Could not locate project root containing llmprop_model.py. "
        "Expected near the deployment folder."
    )

# Prepend the project root so the import below resolves against it first.
if PROJECT_DIR not in sys.path and os.path.isdir(PROJECT_DIR):
    sys.path.insert(0, PROJECT_DIR)

from llmprop_model import T5Predictor
def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Invert z-score scaling of a value.

    Args:
        scaled_labels: Standardized value(s) to map back.
        labels_mean: Mean that was subtracted during normalization.
        labels_std: Standard deviation that was divided by during normalization.

    Returns:
        ``scaled_labels * labels_std + labels_mean``. The result type follows
        the operands' own ``*`` / ``+`` semantics (plain numbers in, plain
        number out).
    """
    return (scaled_labels * labels_std) + labels_mean
# -------------------------
# CONFIG
# -------------------------
# All artifact locations are resolved relative to PROJECT_DIR.
MODEL_PATH = os.path.join(
    PROJECT_DIR, "checkpoints", "samples", "regression", "best_checkpoint_for_volume.pt"
)
TOKENIZER_PATH = os.path.join(
    PROJECT_DIR, "tokenizers", "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge"
)
TRAIN_DATA_PATH = os.path.join(PROJECT_DIR, "data", "samples", "train_data.csv")

PROPERTY_NAME = "volume"
DEVICE = torch.device("cpu")

# Keep the terminal clean: mute Hugging Face / Transformers start-up chatter.
# setdefault() leaves any value the user already exported untouched.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)

# -------------------------
# PATH CHECKS
# -------------------------
# Fail early, with the offending path in the message, if an artifact is absent.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH}")

if not os.path.exists(TOKENIZER_PATH):
    raise FileNotFoundError(f"Tokenizer path not found: {TOKENIZER_PATH}")

# Hard-coded statistics passed to z_denormalize() at prediction time.
# NOTE(review): presumably the mean / std of the training-set volume labels --
# confirm against the training data before changing the checkpoint.
TRAIN_LABEL_MEAN = torch.tensor(481.133881)
TRAIN_LABEL_STD = torch.tensor(528.036194)
def _quiet_call(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` with stdout and stderr suppressed.

    Output written through ``sys.stdout`` / ``sys.stderr`` during the call is
    captured in memory. On success it is discarded and the return value of
    ``fn`` is passed through. If ``fn`` raises, the captured output is replayed
    on the real stderr before the exception propagates, so the diagnostics that
    led up to the failure are not lost.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns.

    Raises:
        Any exception raised by ``fn`` (re-raised unchanged).
    """
    captured_out = io.StringIO()
    captured_err = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured_out), contextlib.redirect_stderr(captured_err):
            return fn(*args, **kwargs)
    except Exception:
        # The redirects are already undone here, so this reaches the real stderr.
        sys.stderr.write(captured_out.getvalue())
        sys.stderr.write(captured_err.getvalue())
        raise
# -------------------------
# LOAD TOKENIZER
# -------------------------
tokenizer = _quiet_call(AutoTokenizer.from_pretrained, TOKENIZER_PATH)

# -------------------------
# LOAD MODEL
# -------------------------
base_model = _quiet_call(T5EncoderModel.from_pretrained, "google/t5-v1_1-small")
# NOTE(review): presumably the encoder hidden size of t5-v1_1-small -- confirm
# before swapping in a different base model.
base_model_output_size = 512

# Match embedding matrix size to the tokenizer used during training.
base_model.resize_token_embeddings(len(tokenizer))

model = T5Predictor(
    base_model,
    base_model_output_size,
    drop_rate=0.1,
    pooling="mean",
)

# -------------------------
# LOAD WEIGHTS
# -------------------------
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch supports it).
state_dict = _quiet_call(torch.load, MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
    model.model.resize_token_embeddings(checkpoint_vocab_size)

# strict=False tolerates key mismatches, but a mismatch means some parameters
# keep their initial values, so report it instead of ignoring it silently.
_load_report = model.load_state_dict(state_dict, strict=False)
if _load_report.missing_keys or _load_report.unexpected_keys:
    logging.getLogger(__name__).warning(
        "Checkpoint/model key mismatch while loading %s: missing=%s unexpected=%s",
        MODEL_PATH,
        list(_load_report.missing_keys),
        list(_load_report.unexpected_keys),
    )

model.to(DEVICE)
model.eval()
# -------------------------
# PREDICT FUNCTION
# -------------------------
def predict_volume(text, max_length=256):
    """Predict the volume for a textual description.

    Args:
        text: Input passed to the module-level tokenizer.
        max_length: Maximum token length; longer inputs are truncated.

    Returns:
        The de-normalized prediction as a Python scalar (via ``.item()``).
    """
    encoded = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        # The model returns a pair; only the second element is used here.
        _unused, normalized = model(token_ids, mask)

    # Map the normalized output back to the label scale.
    denormalized = z_denormalize(
        normalized.squeeze(),
        TRAIN_LABEL_MEAN,
        TRAIN_LABEL_STD,
    )
    return denormalized.item()
# -------------------------
# TEST
# -------------------------
if __name__ == "__main__":
    # Command-line entry point: predict the volume for --text and print it.
    parser = argparse.ArgumentParser(description="Predict volume from text")
    parser.add_argument("--max_length", type=int, default=256, help="Tokenizer max length")
    parser.add_argument(
        "--text",
        type=str,
        default="A simple cubic crystalLiAl(MoO₄)₂ crystallizes in the triclinic P̅1 space group. Li¹⁺ is bonded in a 5-coordinate geometry to five O²⁻ atoms. There are a spread of Li-O bond distances ranging from 1.98-2.25 A. There are two inequivalent Mo⁶⁺ sites. In the first Mo⁶⁺ site, Mo⁶⁺ is bonded in a 4-coordinate geometry to five O²⁻ atoms. There are a spread of Mo-O bond distances ranging from 1.74-2.46 A. In the second Mo⁶⁺ site, Mo⁶⁺ is bonded to four O²⁻ atoms to form MoO₄ tetrahedra that share corners with three equivalent AlO₆ octahedra. The corner-sharing octahedral tilt angles range from 15-44 degrees. There are a spread of Mo-O bond distances ranging from 1.77-1.82 A. Al³⁺ is bonded to six O²⁻ atoms to form AlO₆ octahedra that share corners with three equivalent MoO₄ tetrahedra and an edgeedge with one AlO₆ octahedra. There are a spread of Al-O bond distances ranging from 1.88-1.95 A. There are eight inequivalent O²⁻ sites. In the first O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Li¹⁺, one Mo⁶⁺, and one Al³⁺ atom. In the second O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Mo⁶⁺ and two equivalent Al³⁺ atoms. In the third O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fourth O²⁻ site, O²⁻ is bonded in a linear geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fifth O²⁻ site, O²⁻ is bonded in a linear geometry to one Mo⁶⁺ and one Al³⁺ atom. In the sixth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the seventh O²⁻ site, O²⁻ is bonded in a 4-coordinate geometry to one Li¹⁺, two equivalent Mo⁶⁺, and one Al³⁺ atom. In the eighth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Mo⁶⁺ and one Al³⁺ atom. with atoms arranged periodically and stable at room temperature.",
        help="Input text to predict volume",
    )
    args = parser.parse_args()

    # Reject a non-positive truncation length with a usage error instead of
    # handing it to the tokenizer.
    if args.max_length < 1:
        parser.error("--max_length must be a positive integer")

    value = predict_volume(args.text, max_length=args.max_length)
    print(f"Predicted volume: {value:.6f}")