File size: 6,625 Bytes
e620469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28f525e
 
e620469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12f65ef
e620469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12f65ef
e620469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import argparse
import os
import sys
import logging
import io
import contextlib

import pandas as pd
import torch
from transformers import AutoTokenizer, T5EncoderModel
from transformers.utils import logging as transformers_logging

# Directory containing this script; the project root is searched for relative
# to it, so the lookup does not depend on the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_CANDIDATES = [
    SCRIPT_DIR,
    os.path.dirname(SCRIPT_DIR),
    os.path.join(os.path.dirname(SCRIPT_DIR), "LLM-Prop"),
]

# The first candidate directory that contains llmprop_model.py becomes the
# project root.
PROJECT_DIR = None
for candidate in PROJECT_CANDIDATES:
    if not os.path.exists(os.path.join(candidate, "llmprop_model.py")):
        continue
    PROJECT_DIR = candidate
    break
else:
    # The loop ran to completion without a match: nothing to import from.
    raise FileNotFoundError(
        "Could not locate project root containing llmprop_model.py. "
        "Expected near the deployment folder."
    )

# Put the project root at the front of the import path (only once) so that
# llmprop_model can be imported from it.
if PROJECT_DIR not in sys.path and os.path.isdir(PROJECT_DIR):
    sys.path.insert(0, PROJECT_DIR)

from llmprop_model import T5Predictor


def z_denormalize(scaled_labels, labels_mean, labels_std):
    """Undo z-score scaling: return ``scaled_labels * labels_std + labels_mean``.

    The arguments only need to support ``*`` and ``+`` with each other
    (plain numbers work; this file passes torch tensors).
    """
    rescaled = scaled_labels * labels_std
    return rescaled + labels_mean

# -------------------------
# CONFIG
# -------------------------
# Checkpoint, tokenizer and training-data locations, all resolved under PROJECT_DIR.
MODEL_PATH = os.path.join(
    PROJECT_DIR, "checkpoints", "samples", "regression",
    "best_checkpoint_for_energy_per_atom.pt",
)
TOKENIZER_PATH = os.path.join(
    PROJECT_DIR, "tokenizers",
    "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge",
)
# NOTE(review): TRAIN_DATA_PATH is defined here but not read anywhere in the
# visible part of this file -- confirm whether it is still needed.
TRAIN_DATA_PATH = os.path.join(PROJECT_DIR, "data", "samples", "train_data.csv")
PROPERTY_NAME = "energy_per_atom"
DEVICE = torch.device("cpu")  # everything below is moved to / run on the CPU

# Quieten Hugging Face / Transformers logging so the terminal output stays clean.
# NOTE(review): the two environment variables are set after transformers has
# already been imported -- confirm the libraries still honor them at this point.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
transformers_logging.set_verbosity_error()
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)


# -------------------------
# PATH CHECKS
# -------------------------
# Fail early with a clear message when a required artifact is missing.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {MODEL_PATH}")
if not os.path.exists(TOKENIZER_PATH):
    raise FileNotFoundError(f"Tokenizer path not found: {TOKENIZER_PATH}")

# Hard-coded statistics used to denormalize the model output.
# NOTE(review): presumably the mean/std of the training labels -- verify
# against the training data, since they are not computed in this file.
TRAIN_LABEL_MEAN = torch.tensor(-40.019421)
TRAIN_LABEL_STD = torch.tensor(17.998217)

def _quiet_call(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` with stdout and stderr silenced.

    Anything written to ``sys.stdout`` / ``sys.stderr`` during the call goes
    to throwaway in-memory buffers. The return value of ``fn`` is passed
    through, and exceptions raised by ``fn`` propagate unchanged.
    """
    discarded_out = io.StringIO()
    discarded_err = io.StringIO()
    with contextlib.redirect_stdout(discarded_out):
        with contextlib.redirect_stderr(discarded_err):
            return fn(*args, **kwargs)


# -------------------------
# LOAD TOKENIZER
# -------------------------
# Tokenizer is loaded from the local tokenizer directory, with console noise suppressed.
tokenizer = _quiet_call(AutoTokenizer.from_pretrained, TOKENIZER_PATH)


# -------------------------
# LOAD MODEL
# -------------------------
base_model = _quiet_call(T5EncoderModel.from_pretrained, "google/t5-v1_1-small")
base_model_output_size = 512  # passed to T5Predictor as its second positional argument

# Match embedding matrix size to the tokenizer used during training.
base_model.resize_token_embeddings(len(tokenizer))

model = T5Predictor(
    base_model,
    base_model_output_size,
    drop_rate=0.1,
    pooling="mean",
)


# -------------------------
# LOAD WEIGHTS
# -------------------------
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch
# version supports it).
state_dict = _quiet_call(torch.load, MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
    model.model.resize_token_embeddings(checkpoint_vocab_size)

# strict=False tolerates key mismatches, but parameters missing from the
# checkpoint are then left at their initial values. Surface any mismatch
# instead of silently producing predictions from a partially loaded model.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    logging.getLogger(__name__).warning(
        "Checkpoint %s did not match the model exactly "
        "(missing keys: %s; unexpected keys: %s). "
        "Predictions may be unreliable.",
        MODEL_PATH,
        list(_load_result.missing_keys),
        list(_load_result.unexpected_keys),
    )
model.to(DEVICE)
model.eval()  # switch to evaluation mode for inference


# -------------------------
# PREDICT FUNCTION
# -------------------------
def predict_epa(text, max_length=256):
    """Predict the denormalized energy_per_atom value for ``text``.

    The text is tokenized (truncated to ``max_length`` tokens), run through
    the module-level ``model`` without gradient tracking, and the second
    element of the model's output is rescaled with ``TRAIN_LABEL_MEAN`` /
    ``TRAIN_LABEL_STD``.

    Args:
        text: Input passed straight to the tokenizer. The final ``.item()``
            call needs the prediction to hold exactly one element, so this
            is presumably a single description per call -- confirm.
        max_length: Maximum token length handed to the tokenizer.

    Returns:
        The prediction as a Python scalar (via ``Tensor.item()``).
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    token_ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    with torch.no_grad():
        # The model returns a pair; only its second element is used here.
        _, normalized = model(token_ids, mask)
        denormalized = z_denormalize(
            normalized.squeeze(),
            TRAIN_LABEL_MEAN,
            TRAIN_LABEL_STD,
        )

    return denormalized.item()


# -------------------------
# TEST
# -------------------------
if __name__ == "__main__":
    # Command-line entry point: parse the options, run a single prediction on
    # the given (or built-in sample) text, and print the result.
    sample_text = "A simple cubic crystalLiAl(MoO₄)₂ crystallizes in the triclinic P̅1 space group. Li¹⁺ is bonded in a 5-coordinate geometry to five O²⁻ atoms. There are a spread of Li–O bond distances ranging from 1.98–2.25 Å. There are two inequivalent Mo⁶⁺ sites. In the first Mo⁶⁺ site, Mo⁶⁺ is bonded in a 4-coordinate geometry to five O²⁻ atoms. There are a spread of Mo–O bond distances ranging from 1.74–2.46 Å. In the second Mo⁶⁺ site, Mo⁶⁺ is bonded to four O²⁻ atoms to form MoO₄ tetrahedra that share corners with three equivalent AlO₆ octahedra. The corner-sharing octahedral tilt angles range from 15–44°. There are a spread of Mo–O bond distances ranging from 1.77–1.82 Å. Al³⁺ is bonded to six O²⁻ atoms to form AlO₆ octahedra that share corners with three equivalent MoO₄ tetrahedra and  an edgeedge with one AlO₆ octahedra. There are a spread of Al–O bond distances ranging from 1.88–1.95 Å. There are eight inequivalent O²⁻ sites. In the first O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Li¹⁺, one Mo⁶⁺, and one Al³⁺ atom. In the second O²⁻ site, O²⁻ is bonded in a distorted trigonal planar geometry to one Mo⁶⁺ and two equivalent Al³⁺ atoms. In the third O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fourth O²⁻ site, O²⁻ is bonded in a linear geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the fifth O²⁻ site, O²⁻ is bonded in a linear geometry to one Mo⁶⁺ and one Al³⁺ atom. In the sixth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Li¹⁺ and one Mo⁶⁺ atom. In the seventh O²⁻ site, O²⁻ is bonded in a 4-coordinate geometry to one Li¹⁺, two equivalent Mo⁶⁺, and one Al³⁺ atom. In the eighth O²⁻ site, O²⁻ is bonded in a bent 150 degrees geometry to one Mo⁶⁺ and one Al³⁺ atom. with atoms arranged periodically and stable at room temperature."

    cli = argparse.ArgumentParser(description="Predict energy_per_atom from text")
    cli.add_argument(
        "--max_length",
        type=int,
        default=256,
        help="Tokenizer max length",
    )
    cli.add_argument(
        "--text",
        type=str,
        default=sample_text,
        help="Input text to predict EPA",
    )
    options = cli.parse_args()

    predicted = predict_epa(options.text, max_length=options.max_length)
    print(f"Predicted energy_per_atom: {predicted:.6f}")