Update README.md
Browse files
README.md
CHANGED
|
@@ -65,17 +65,20 @@ id2label = {idx: label for label, idx in label2id.items()}
|
|
| 65 |
model = AutoModelForSequenceClassification.from_pretrained("atulgupta002/banking_customer_service_query_intent_classifier")
|
| 66 |
tokenizer = AutoTokenizer.from_pretrained("atulgupta002/banking_customer_service_query_intent_classifier")
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
|
| 78 |
-
|
| 79 |
-
predicted_class = logits.argmax().item()
|
| 80 |
-
print(f"Predicted intent: {id2label[predicted_class]}")
|
| 81 |
|
|
|
|
|
|
| 65 |
model = AutoModelForSequenceClassification.from_pretrained("atulgupta002/banking_customer_service_query_intent_classifier")
|
| 66 |
tokenizer = AutoTokenizer.from_pretrained("atulgupta002/banking_customer_service_query_intent_classifier")
|
| 67 |
|
| 68 |
+
def predict(text):
|
| 69 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
| 70 |
+
|
| 71 |
+
with torch.no_grad():
|
| 72 |
+
outputs = model(**inputs)
|
| 73 |
+
logits = outputs.logits
|
| 74 |
+
predicted_class_id = logits.argmax().item()
|
| 75 |
|
| 76 |
+
return id2label[predicted_class_id]
|
|
|
|
| 77 |
|
| 78 |
+
query = "I want to apply for a new credit card"
|
| 79 |
+
print(predict(query))
|
| 80 |
+
```
|
| 81 |
|
| 82 |
+
## Sample output
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+

|