Spaces:
Running
Running
File size: 4,446 Bytes
e620469 12f65ef e620469 12f65ef e620469 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 | import torch
import os
import sys
import argparse
from transformers import AutoTokenizer, T5EncoderModel
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_CANDIDATES = [
SCRIPT_DIR,
os.path.dirname(SCRIPT_DIR),
os.path.join(os.path.dirname(SCRIPT_DIR), "LLM-Prop"),
]
PROJECT_DIR = None
for candidate in PROJECT_CANDIDATES:
if os.path.exists(os.path.join(candidate, "llmprop_model.py")):
PROJECT_DIR = candidate
break
if PROJECT_DIR is None:
raise FileNotFoundError(
"Could not locate project root containing llmprop_model.py. "
"Expected near the deployment folder."
)
if os.path.isdir(PROJECT_DIR) and PROJECT_DIR not in sys.path:
sys.path.insert(0, PROJECT_DIR)
from llmprop_model import T5Predictor
# -------------------------
# CONFIG
# -------------------------
MODEL_PATH = os.path.join(PROJECT_DIR, "checkpoints", "samples", "classification", "best_checkpoint_for_is_gap_direct.pt")
TOKENIZER_PATH = os.path.join(PROJECT_DIR, "tokenizers", "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge")
DEVICE = torch.device("cpu")
# -------------------------
# LOAD TOKENIZER
# -------------------------
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
# -------------------------
# LOAD MODEL
# -------------------------
base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
base_model_output_size = 512
# Match embedding matrix size to the tokenizer used during training.
base_model.resize_token_embeddings(len(tokenizer))
model = T5Predictor(
base_model,
base_model_output_size,
drop_rate=0.1,
pooling="mean" # ✅ confirmed from your command
)
# -------------------------
# LOAD WEIGHTS
# -------------------------
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
# Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint.
checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0]
if model.model.shared.weight.shape[0] != checkpoint_vocab_size:
model.model.resize_token_embeddings(checkpoint_vocab_size)
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)
model.eval()
# -------------------------
# PREDICT FUNCTION
# -------------------------
def predict(text, threshold=0.33):
# ❌ NO preprocessing (important)
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=256 # ✅ from your command
)
input_ids = inputs["input_ids"].to(DEVICE)
attention_mask = inputs["attention_mask"].to(DEVICE)
with torch.no_grad():
_, predictions = model(input_ids, attention_mask)
prob = torch.sigmoid(predictions).item()
if prob > threshold:
return "TRUE", prob
else:
return "FALSE", prob
# -------------------------
# TEST
# -------------------------
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Predict is_gap_direct from text")
parser.add_argument("--threshold", type=float, default=0.33, help="Decision threshold for TRUE/FALSE")
parser.add_argument("--text", type=str, default="Rb₂NaPrCl₆ is (Cubic) Perovskite-derived structured and crystallizes in the cubic Fm̅3m space group. Rb¹⁺ is bonded to twelve equivalent Cl¹⁻ atoms to form RbCl₁₂ cuboctahedra that share corners with twelve equivalent RbCl₁₂ cuboctahedra, faces with six equivalent RbCl₁₂ cuboctahedra, faces with four equivalent NaCl₆ octahedra, and faces with four equivalent PrCl₆ octahedra. All Rb–Cl bond lengths are 3.90 Å. Na¹⁺ is bonded to six equivalent Cl¹⁻ atoms to form NaCl₆ octahedra that share corners with six equivalent PrCl₆ octahedra and faces with eight equivalent RbCl₁₂ cuboctahedra. The corner-sharing octahedra are not tilted. All Na–Cl bond lengths are 2.76 Å. Pr³⁺ is bonded to six equivalent Cl¹⁻ atoms to form PrCl₆ octahedra that share corners with six equivalent NaCl₆ octahedra and faces with eight equivalent RbCl₁₂ cuboctahedra. The corner-sharing octahedra are not tilted. All Pr–Cl bond lengths are 2.75 Å. Cl¹⁻ is bonded in a distorted linear geometry to four equivalent Rb¹⁺, one Na¹⁺, and one Pr³⁺ atom.", help="Input text to classify")
args = parser.parse_args()
result, prob = predict(args.text, threshold=args.threshold)
print(result, prob) |