Spaces:
Running
Running
| import torch | |
| import os | |
| import sys | |
| import argparse | |
| from transformers import AutoTokenizer, T5EncoderModel | |
| SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| PROJECT_CANDIDATES = [ | |
| SCRIPT_DIR, | |
| os.path.dirname(SCRIPT_DIR), | |
| os.path.join(os.path.dirname(SCRIPT_DIR), "LLM-Prop"), | |
| ] | |
| PROJECT_DIR = None | |
| for candidate in PROJECT_CANDIDATES: | |
| if os.path.exists(os.path.join(candidate, "llmprop_model.py")): | |
| PROJECT_DIR = candidate | |
| break | |
| if PROJECT_DIR is None: | |
| raise FileNotFoundError( | |
| "Could not locate project root containing llmprop_model.py. " | |
| "Expected near the deployment folder." | |
| ) | |
| if os.path.isdir(PROJECT_DIR) and PROJECT_DIR not in sys.path: | |
| sys.path.insert(0, PROJECT_DIR) | |
| from llmprop_model import T5Predictor | |
| # ------------------------- | |
| # CONFIG | |
| # ------------------------- | |
| MODEL_PATH = os.path.join(PROJECT_DIR, "checkpoints", "samples", "classification", "best_checkpoint_for_is_gap_direct.pt") | |
| TOKENIZER_PATH = os.path.join(PROJECT_DIR, "tokenizers", "t5_tokenizer_trained_on_modified_part_of_C4_and_textedge") | |
| DEVICE = torch.device("cpu") | |
| # ------------------------- | |
| # LOAD TOKENIZER | |
| # ------------------------- | |
| tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH) | |
| # ------------------------- | |
| # LOAD MODEL | |
| # ------------------------- | |
| base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small") | |
| base_model_output_size = 512 | |
| # Match embedding matrix size to the tokenizer used during training. | |
| base_model.resize_token_embeddings(len(tokenizer)) | |
| model = T5Predictor( | |
| base_model, | |
| base_model_output_size, | |
| drop_rate=0.1, | |
| pooling="mean" # ✅ confirmed from your command | |
| ) | |
| # ------------------------- | |
| # LOAD WEIGHTS | |
| # ------------------------- | |
| state_dict = torch.load(MODEL_PATH, map_location=DEVICE) | |
| # Some checkpoints were trained with an extra tokenizer token; align embedding size to checkpoint. | |
| checkpoint_vocab_size = state_dict["model.shared.weight"].shape[0] | |
| if model.model.shared.weight.shape[0] != checkpoint_vocab_size: | |
| model.model.resize_token_embeddings(checkpoint_vocab_size) | |
| model.load_state_dict(state_dict, strict=False) | |
| model.to(DEVICE) | |
| model.eval() | |
| # ------------------------- | |
| # PREDICT FUNCTION | |
| # ------------------------- | |
| def predict(text, threshold=0.33): | |
| # ❌ NO preprocessing (important) | |
| inputs = tokenizer( | |
| text, | |
| return_tensors="pt", | |
| truncation=True, | |
| padding=True, | |
| max_length=256 # ✅ from your command | |
| ) | |
| input_ids = inputs["input_ids"].to(DEVICE) | |
| attention_mask = inputs["attention_mask"].to(DEVICE) | |
| with torch.no_grad(): | |
| _, predictions = model(input_ids, attention_mask) | |
| prob = torch.sigmoid(predictions).item() | |
| if prob > threshold: | |
| return "TRUE", prob | |
| else: | |
| return "FALSE", prob | |
| # ------------------------- | |
| # TEST | |
| # ------------------------- | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="Predict is_gap_direct from text") | |
| parser.add_argument("--threshold", type=float, default=0.33, help="Decision threshold for TRUE/FALSE") | |
| parser.add_argument("--text", type=str, default="Rb₂NaPrCl₆ is (Cubic) Perovskite-derived structured and crystallizes in the cubic Fm̅3m space group. Rb¹⁺ is bonded to twelve equivalent Cl¹⁻ atoms to form RbCl₁₂ cuboctahedra that share corners with twelve equivalent RbCl₁₂ cuboctahedra, faces with six equivalent RbCl₁₂ cuboctahedra, faces with four equivalent NaCl₆ octahedra, and faces with four equivalent PrCl₆ octahedra. All Rb–Cl bond lengths are 3.90 Å. Na¹⁺ is bonded to six equivalent Cl¹⁻ atoms to form NaCl₆ octahedra that share corners with six equivalent PrCl₆ octahedra and faces with eight equivalent RbCl₁₂ cuboctahedra. The corner-sharing octahedra are not tilted. All Na–Cl bond lengths are 2.76 Å. Pr³⁺ is bonded to six equivalent Cl¹⁻ atoms to form PrCl₆ octahedra that share corners with six equivalent NaCl₆ octahedra and faces with eight equivalent RbCl₁₂ cuboctahedra. The corner-sharing octahedra are not tilted. All Pr–Cl bond lengths are 2.75 Å. Cl¹⁻ is bonded in a distorted linear geometry to four equivalent Rb¹⁺, one Na¹⁺, and one Pr³⁺ atom.", help="Input text to classify") | |
| args = parser.parse_args() | |
| result, prob = predict(args.text, threshold=args.threshold) | |
| print(result, prob) |