| # predict.py | |
import pickle

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers.utils import cached_file
# Hugging Face Hub repo id (or local directory) holding the fine-tuned
# classifier, its tokenizer, and the pickled label encoder.
model_path = 'shirleylqs/mistral-snomed-classification'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# This script only runs inference; make eval mode (dropout disabled) explicit.
model.eval()

# Resolve the encoder file through cached_file rather than building
# f'{model_path}/label_encoder.pkl' by hand: a plain open() on a Hub repo id
# raises FileNotFoundError, whereas cached_file accepts either a local
# directory or a repo id (downloading into the Hugging Face cache).
label_encoder_file = cached_file(model_path, 'label_encoder.pkl')
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file. Only load label_encoder.pkl from a repository you trust; consider
# publishing the class names as JSON instead.
with open(label_encoder_file, 'rb') as f:
    label_encoder = pickle.load(f)
def predict_class(text, *, max_length=128):
    """Return the predicted label for a single piece of text.

    The text is tokenized (truncated to ``max_length`` tokens) and run through
    the module-level ``model`` without gradient tracking. The index of the
    highest logit is then mapped back to its original label via
    ``label_encoder.inverse_transform``.

    Args:
        text: The input string to classify.
        max_length: Maximum number of tokens kept after truncation. Defaults
            to 128, the value this function has always used.

    Returns:
        The decoded label for the arg-max logit, exactly as returned by
        ``label_encoder.inverse_transform``.
    """
    # The tokenizer returns CPU tensors; move them to wherever the model's
    # weights live so this keeps working if the model is placed on a GPU.
    inputs = tokenizer(
        text, return_tensors='pt', truncation=True, max_length=max_length
    ).to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Assumes a single input text, so the arg-max yields a one-element tensor;
    # .item() raises if more than one element remains.
    predicted_class_id = logits.argmax(-1).item()
    return label_encoder.inverse_transform([predicted_class_id])[0]
if __name__ == "__main__":
    # Smoke test when run as a script: classify one sample phrase and print
    # the decoded label.
    sample_text = "purulent discharge"
    print(predict_class(sample_text))