File size: 4,446 Bytes
e620469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12f65ef
e620469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12f65ef
e620469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import os
import sys
import argparse
from transformers import AutoTokenizer, T5EncoderModel

# Directory that holds this script; the project root is searched for relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PARENT_DIR = os.path.dirname(SCRIPT_DIR)

# Locations probed, in order: the script directory, its parent, and a sibling "LLM-Prop" folder.
PROJECT_CANDIDATES = [
    SCRIPT_DIR,
    _PARENT_DIR,
    os.path.join(_PARENT_DIR, "LLM-Prop"),
]

# The first candidate that contains llmprop_model.py becomes the project root;
# if none does, the script cannot continue.
for candidate in PROJECT_CANDIDATES:
    marker_file = os.path.join(candidate, "llmprop_model.py")
    if os.path.exists(marker_file):
        PROJECT_DIR = candidate
        break
else:
    raise FileNotFoundError(
        "Could not locate project root containing llmprop_model.py. "
        "Expected near the deployment folder."
    )

# Put the project root at the front of the import path (once) so that
# llmprop_model can be imported below.
if os.path.isdir(PROJECT_DIR):
    if PROJECT_DIR not in sys.path:
        sys.path.insert(0, PROJECT_DIR)

from llmprop_model import T5Predictor

# -------------------------
# CONFIG
# -------------------------
# Checkpoint file loaded as the model weights, located under the project root.
MODEL_PATH = os.path.join(
    PROJECT_DIR,
    "checkpoints",
    "samples",
    "classification",
    "best_checkpoint_for_is_gap_direct.pt",
)

# Directory of the tokenizer handed to AutoTokenizer.from_pretrained.
TOKENIZER_PATH = os.path.join(
    PROJECT_DIR,
    "tokenizers",
    "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge",
)

# Inference runs on the CPU.
DEVICE = torch.device("cpu")

# -------------------------
# LOAD TOKENIZER
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

# -------------------------
# LOAD MODEL
# -------------------------
# Size handed to T5Predictor as the encoder output width
# (presumably the hidden size of the base model - confirm against its config).
base_model_output_size = 512
base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

# The embedding matrix must have one row per tokenizer entry, so resize it
# to the vocabulary of the tokenizer loaded above.
base_model.resize_token_embeddings(len(tokenizer))

# Keyword options for the predictor head; the original note states that the
# "mean" pooling mode was confirmed from the command used for training.
_predictor_options = {"drop_rate": 0.1, "pooling": "mean"}
model = T5Predictor(base_model, base_model_output_size, **_predictor_options)

# -------------------------
# LOAD WEIGHTS
# -------------------------
# Read the checkpoint onto the target device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)

# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
    model.model.resize_token_embeddings(checkpoint_vocab_size)

# strict=False tolerates key mismatches between checkpoint and model. Report
# them instead of dropping them silently, so that running with partially
# initialised weights does not go unnoticed.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        "Warning: checkpoint does not fully match the model "
        f"(missing keys: {list(load_result.missing_keys)}; "
        f"unexpected keys: {list(load_result.unexpected_keys)})",
        file=sys.stderr,
    )

# Move to the inference device and switch to evaluation mode.
model.to(DEVICE)
model.eval()

# -------------------------
# PREDICT FUNCTION
# -------------------------
def predict(text, threshold=0.33):
    """Classify a text and return the label together with its probability.

    The text is tokenized (truncated to at most 256 tokens), run through the
    module-level ``model`` with gradient tracking disabled, and the second
    element of the model output is passed through a sigmoid.

    Args:
        text: Input handed to the tokenizer unchanged. The score is read with
            ``.item()``, so presumably a single string is expected - confirm.
        threshold: Probability strictly above which ``"TRUE"`` is returned.

    Returns:
        A ``(label, prob)`` tuple where ``label`` is ``"TRUE"`` when
        ``prob > threshold`` and ``"FALSE"`` otherwise, and ``prob`` is the
        sigmoid output as a Python number.
    """
    # Deliberately no text preprocessing: the raw input goes to the tokenizer.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,  # noted in the original as taken from the training command
    )
    token_ids = encoded["input_ids"].to(DEVICE)
    mask = encoded["attention_mask"].to(DEVICE)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        _unused, logits = model(token_ids, mask)
        prob = torch.sigmoid(logits).item()

    label = "TRUE" if prob > threshold else "FALSE"
    return label, prob


# -------------------------
# TEST
# -------------------------
if __name__ == "__main__":
    # Example structure description used when --text is not given on the command line.
    default_text = "Rb₂NaPrCl₆ is (Cubic) Perovskite-derived structured and crystallizes in the cubic Fm̅3m space group. Rb¹⁺ is bonded to twelve equivalent Cl¹⁻ atoms to form RbCl₁₂ cuboctahedra that share corners with twelve equivalent RbCl₁₂ cuboctahedra, faces with six equivalent RbCl₁₂ cuboctahedra, faces with four equivalent NaCl₆ octahedra, and faces with four equivalent PrCl₆ octahedra. All Rb–Cl bond lengths are 3.90 Å. Na¹⁺ is bonded to six equivalent Cl¹⁻ atoms to form NaCl₆ octahedra that share corners with six equivalent PrCl₆ octahedra and faces with eight equivalent RbCl₁₂ cuboctahedra. The corner-sharing octahedra are not tilted. All Na–Cl bond lengths are 2.76 Å. Pr³⁺ is bonded to six equivalent Cl¹⁻ atoms to form PrCl₆ octahedra that share corners with six equivalent NaCl₆ octahedra and faces with eight equivalent RbCl₁₂ cuboctahedra. The corner-sharing octahedra are not tilted. All Pr–Cl bond lengths are 2.75 Å. Cl¹⁻ is bonded in a distorted linear geometry to four equivalent Rb¹⁺, one Na¹⁺, and one Pr³⁺ atom."

    # Command-line interface: an optional decision threshold and an optional input text.
    cli = argparse.ArgumentParser(description="Predict is_gap_direct from text")
    cli.add_argument(
        "--threshold",
        default=0.33,
        type=float,
        help="Decision threshold for TRUE/FALSE",
    )
    cli.add_argument(
        "--text",
        default=default_text,
        type=str,
        help="Input text to classify",
    )
    options = cli.parse_args()

    # Classify the text and print the label followed by its probability.
    label, probability = predict(options.text, threshold=options.threshold)
    print(label, probability)