def _resolve_threshold(model_dir, threshold, threshold_type):
    """
    Pick the decision threshold and report where it came from.

    Resolution order: the explicit ``threshold`` argument, then the
    ``threshold_type`` entry of ``all_thresholds.json`` (only consulted when
    ``threshold_type`` is given), then ``threshold_config.json``, then 0.5.

    Args:
        model_dir: Directory that may contain the threshold JSON files
        threshold: Caller-supplied threshold, or None to look one up
        threshold_type: Key to look up in all_thresholds.json, or None

    Returns:
        Tuple of (threshold, threshold_source)
    """
    import json
    import os

    threshold_source = "custom"

    if threshold is None and threshold_type is not None:
        try:
            with open(os.path.join(model_dir, "all_thresholds.json"), "r", encoding="utf-8") as f:
                all_thresholds = json.load(f)
        except FileNotFoundError:
            # Best effort: the caller asked for a specific type, so say why
            # it is being ignored instead of falling through silently.
            print(f"all_thresholds.json not found in '{model_dir}'. Ignoring threshold type '{threshold_type}'.")
        else:
            if threshold_type in all_thresholds:
                threshold = all_thresholds[threshold_type]["threshold"]
                threshold_source = f"all_thresholds.json ({threshold_type})"
            else:
                print(f"Threshold type '{threshold_type}' not found. Available types: {list(all_thresholds.keys())}")

    if threshold is None:
        try:
            with open(os.path.join(model_dir, "threshold_config.json"), "r", encoding="utf-8") as f:
                config = json.load(f)
            threshold = config["threshold"]
            threshold_source = "threshold_config.json"
        except FileNotFoundError:
            threshold = 0.5
            threshold_source = "default"
            print("No threshold configuration found. Using default threshold of 0.5.")

    return threshold, threshold_source


def predict_with_threshold(text, model_dir="cross-sectional-classifier", threshold=None, threshold_type=None):
    """
    Make a prediction using the saved model and custom threshold.

    The text is labelled "Cross-sectional" when the probability of class 1 is
    strictly greater than the threshold, and "Other" otherwise.

    Args:
        text: Input text to classify (a single text, not a batch)
        model_dir: Directory where the model and threshold are saved
        threshold: Custom threshold to use (if None, loads from saved config)
        threshold_type: Type of threshold to use ('f1' or 'balanced') if loading from all_thresholds.json

    Returns:
        Dictionary with prediction results:
            "prediction": "Cross-sectional" or "Other"
            "probabilities": {"Cross-sectional": ..., "Other": ...}
            "threshold_used": the threshold that was applied
            "threshold_source": "custom", "all_thresholds.json (<type>)",
                "threshold_config.json" or "default"

    Raises:
        TypeError: If the resolved threshold cannot be compared with a float
        KeyError: If a threshold file is found but has no "threshold" entry
        ValueError: If the model returns logits for anything other than one input
    """
    # Resolve and check the threshold first: this needs only the standard
    # library, so a bad value fails fast instead of after the model has been
    # loaded and run.
    threshold, threshold_source = _resolve_threshold(model_dir, threshold, threshold_type)
    try:
        # Probe with the same comparison used for the decision below, so every
        # value that is usable there (int, float, numpy scalar, ...) is accepted.
        _ = 0.5 > threshold
    except TypeError as exc:
        raise TypeError(
            f"threshold from {threshold_source} must be a number comparable with a "
            f"probability, got {type(threshold).__name__}: {threshold!r}"
        ) from exc

    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)

    # NOTE(review): padding a single text to max_length is kept as in the
    # original; it is probably unnecessary for one input - confirm before changing.
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    # Index the single row explicitly rather than squeeze(): squeeze() silently
    # changes meaning when a batch is passed or when there is only one label.
    if logits.shape[0] != 1:
        raise ValueError(
            f"predict_with_threshold classifies one text at a time, but the model "
            f"returned logits for {logits.shape[0]} inputs"
        )
    row = logits[0]

    if row.shape[0] == 1:
        # Softmax over a single logit is always 1.0, which would make every
        # prediction "Cross-sectional". Use a sigmoid instead.
        # Assumes the single output is a binary logit for the positive class -
        # TODO confirm against how the model was trained.
        positive = torch.sigmoid(row[0]).item()
        negative = 1 - positive
    else:
        probs = torch.nn.functional.softmax(row, dim=0).tolist()
        negative, positive = probs[0], probs[1]

    return {
        "prediction": "Cross-sectional" if positive > threshold else "Other",
        "probabilities": {
            "Cross-sectional": positive,
            "Other": negative,
        },
        "threshold_used": threshold,
        "threshold_source": threshold_source,
    }
# Demo run: classify one example sentence, leaving the model directory,
# threshold and threshold type at the function's defaults.
sample_text = (
    "This is a cross-sectional study that aims to investigate "
    "the relationship between smoking and lung cancer."
)
result = predict_with_threshold(text=sample_text)