Spaces:
Sleeping
Sleeping
# =========================================================
# BERT URGENCY PREDICTION — ENGLISH
# =========================================================
import os
import pickle

import torch
from transformers import BertTokenizer, BertForSequenceClassification

# ── Paths ─────────────────────────────────────────────────
# Resolve relative to this file's absolute location: os.path.dirname() of a
# bare relative __file__ (e.g. "predict.py") is "", which would silently make
# the artifact path depend on the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "artifacts", "urgency_bert_model")

# ── Load tokenizer + model from HF Hub ────────────────────
# Single source of truth for the Hub repo id so the tokenizer and the model
# can never be loaded from different repos by accident.
HF_MODEL_ID = "mohanbot799s/civicconnect-urgency-en"
tokenizer = BertTokenizer.from_pretrained(HF_MODEL_ID)
model = BertForSequenceClassification.from_pretrained(HF_MODEL_ID)

# ── Load label encoder (local artifact) ───────────────────
# Maps predicted class indices back to urgency labels (predict_urgency calls
# its inverse_transform). The context manager closes the file handle; the
# previous bare open() left it open until garbage collection.
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# NOTE(review): presumably this file is bundled with the app and trusted —
# confirm it can never be user-supplied or downloaded from an untrusted source.
with open(os.path.join(MODEL_DIR, "label_encoder.pkl"), "rb") as _fh:
    label_encoder = pickle.load(_fh)

# Switch the model to evaluation (inference) mode.
model.eval()

# Maximum number of tokens per input; longer texts are truncated by the
# tokenizer in predict_urgency.
MAX_LENGTH = 128
| # ββ Predict βββββββββββββββββββββββββββββββββββββββββββββββ | |
def predict_urgency(
    text: str,
    input_ids=None,
    attention_mask=None,
) -> dict:
    """Classify the urgency of a single grievance.

    Args:
        text: Raw grievance text. Ignored when ``input_ids`` is supplied.
        input_ids: Optional pre-tokenized input ids (a tensor accepted by the
            model). When given, tokenization is skipped entirely.
        attention_mask: Optional mask to pass alongside ``input_ids``. When
            ``input_ids`` is None, this argument is replaced by the
            tokenizer's own attention mask.

    Returns:
        A dict with:
          - "urgency": the decoded label from ``label_encoder``,
          - "confidence": the highest softmax probability, rounded to 4 places,
          - "class_index": the index of that highest-probability class.
    """
    if input_ids is None:
        # Tokenize a single, unpadded sequence, truncated to MAX_LENGTH tokens.
        encoded = tokenizer(
            text,
            max_length=MAX_LENGTH,
            truncation=True,
            padding=False,
            return_tensors="pt",
        )
        input_ids, attention_mask = encoded["input_ids"], encoded["attention_mask"]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Turn logits into class probabilities, then take the most likely class.
    probabilities = torch.softmax(logits, dim=1)
    top_prob, top_idx = probabilities.max(dim=1)
    class_index = top_idx.item()

    return {
        "urgency": label_encoder.inverse_transform([class_index])[0],
        "confidence": round(top_prob.item(), 4),
        "class_index": class_index,
    }
def get_model_and_tokenizer():
    """Return the module-level ``(model, tokenizer)`` pair loaded at import time.

    Callers can reuse these already-loaded objects — e.g. to pre-tokenize text
    themselves and pass ``input_ids`` / ``attention_mask`` to
    ``predict_urgency`` — rather than loading them a second time.
    """
    loaded = (model, tokenizer)
    return loaded
if __name__ == "__main__":
    print("\nBERT Urgency Prediction Test")
    # Interactive loop: classify each entered grievance until the user types
    # "exit", closes stdin, or presses Ctrl+C.
    while True:
        try:
            text = input("\nEnter grievance (or 'exit'): ")
        except (EOFError, KeyboardInterrupt):
            # EOF (Ctrl+D, or piped input exhausted) or Ctrl+C: leave the loop
            # cleanly instead of crashing with a traceback.
            print()
            break
        # Strip so that stray whitespace (e.g. "exit ") still matches.
        text = text.strip()
        if text.lower() == "exit":
            break
        if not text:
            # Nothing to classify; prompt again.
            continue
        print(predict_urgency(text))