---
license: mit
language:
- en
library_name: pytorch
---
# Athena-Intent

Classifies the intent of user queries for Athena.

## Architecture

The model uses a `distilbert-base-uncased` backbone fine-tuned for multiclass sequence classification over three intent classes.

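For readers unfamiliar with how a DistilBERT backbone is adapted to classification, the standard Hugging Face `DistilBertForSequenceClassification` head (and its TensorFlow counterpart) works like this. It takes the final hidden state of the first (`[CLS]`) token and passes it through a `pre_classifier` dense layer with ReLU. It applies dropout during training, then projects to one logit per class. The NumPy sketch below reproduces that forward pass. It uses random placeholder weights and assumes a hidden size of 768. The weight names, the random initialization, and the use of NumPy are illustrative assumptions, not the trained Athena-Intent parameters.

```python
import numpy as np

HIDDEN_DIM = 768   # hidden size of distilbert-base-uncased
NUM_CLASSES = 3    # Keyword Search, Semantic Search, Direct Question Answering

# Placeholder weights (assumption): the real values live in the fine-tuned checkpoint.
rng = np.random.default_rng(seed=0)
W_pre = rng.normal(0.0, 0.02, size=(HIDDEN_DIM, HIDDEN_DIM))
b_pre = np.zeros(HIDDEN_DIM)
W_cls = rng.normal(0.0, 0.02, size=(HIDDEN_DIM, NUM_CLASSES))
b_cls = np.zeros(NUM_CLASSES)


def classification_head(hidden_states: np.ndarray) -> np.ndarray:
    """Map backbone output (batch, seq_len, HIDDEN_DIM) to logits (batch, NUM_CLASSES)."""
    cls_vectors = hidden_states[:, 0, :]              # first-token ([CLS]) representation
    x = np.maximum(cls_vectors @ W_pre + b_pre, 0.0)  # pre_classifier dense layer + ReLU
    # Dropout is only active during training, so it is omitted at inference time.
    return x @ W_cls + b_cls                          # one logit per intent class


# Usage: fake backbone output for 2 queries of 12 tokens each.
fake_hidden = rng.normal(size=(2, 12, HIDDEN_DIM))
logits = classification_head(fake_hidden)
print(logits.shape)  # (2, 3)
```

Only the first token's vector reaches the classifier. The backbone's self-attention is what lets that single vector summarize the whole query.
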
### Description

Classifies the intent of a user query into one of the following classes:

- `0`: Keyword Search
- `1`: Semantic Search
- `2`: Direct Question Answering

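The class indices above can be captured as a pair of lookup tables. This is a minimal sketch that assumes the index order listed above. `ID2LABEL`, `LABEL2ID`, and `decode_intent` are hypothetical names. Dictionaries of this shape are what Hugging Face `transformers` accepts for the `id2label` and `label2id` configuration fields. If they were stored in the model config, `model.config.id2label` would return readable names instead of the generic `LABEL_0`-style defaults.

```python
# Hypothetical lookup tables; index order follows the class list above.
ID2LABEL = {
    0: "Keyword Search",
    1: "Semantic Search",
    2: "Direct Question Answering",
}
LABEL2ID = {label: index for index, label in ID2LABEL.items()}


def decode_intent(class_index: int) -> str:
    """Return the human-readable intent for a predicted class index."""
    if class_index not in ID2LABEL:
        raise ValueError(f"unknown class index: {class_index}")
    return ID2LABEL[class_index]


print(decode_intent(2))  # Direct Question Answering
```
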
## Uses

This model is intended for use in Athena, to support question answering (QA) over enterprise document stores.

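A downstream system could act on the predicted intent by dispatching each query to a matching retrieval strategy. The sketch below is hypothetical. `route_query`, `HANDLERS`, and the three handler functions are placeholders and do not describe Athena's actual internals. The handlers only return strings so the control flow can run in isolation.

```python
from typing import Callable, Dict


# Hypothetical handlers standing in for real search and QA backends.
def keyword_search(query: str) -> str:
    return f"keyword search: {query}"


def semantic_search(query: str) -> str:
    return f"semantic search: {query}"


def answer_question(query: str) -> str:
    return f"direct QA: {query}"


# Dispatch table keyed by the model's class index (0, 1, 2).
HANDLERS: Dict[int, Callable[[str], str]] = {
    0: keyword_search,
    1: semantic_search,
    2: answer_question,
}


def route_query(query: str, predicted_class: int) -> str:
    """Send the query to the handler registered for its predicted intent."""
    handler = HANDLERS.get(predicted_class)
    if handler is None:
        raise ValueError(f"no handler for class {predicted_class}")
    return handler(query)


print(route_query("What is a CDP?", 2))  # direct QA: What is a CDP?
```

A dispatch table keeps the mapping from intent to backend in one place. Adding or swapping a backend then does not touch the routing logic.
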
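## Bias, Risks, and Limitations

The training dataset was generated with ChatGPT (`gpt-3.5-turbo`) and consists of 5,000 English sentences, each manually annotated with its intent. Because the data is synthetic and English-only, the model may perform worse on non-English queries or on phrasing that differs from ChatGPT-generated text.
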
## Usage

The snippet below loads the model with TensorFlow and predicts the intent of a single query.

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFDistilBertForSequenceClassification

model = TFDistilBertForSequenceClassification.from_pretrained("sourcerersupreme/athena-intent")
tokenizer = AutoTokenizer.from_pretrained("sourcerersupreme/athena-intent")

# Map class indices to short intent labels
class_semantic_mapping = {
    0: "Keyword",
    1: "Semantic",
    2: "QA",
}

# Example user query
user_query = "What is a CDP?"

# Tokenize the query into TensorFlow tensors
inputs = tokenizer(user_query, return_tensors="tf", truncation=True, padding=True)

# Forward pass; logits has shape (batch_size, num_classes)
logits = model(**inputs).logits

# Index of the highest-scoring class for the single query in the batch
predicted_class = int(tf.math.argmax(logits, axis=-1)[0])

print(f"Predicted class: {class_semantic_mapping[predicted_class]}")
```
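The example above reports only the arg-max class. To also get a confidence score, the logits can be turned into probabilities with a softmax. Passing a list of queries to the tokenizer with `padding=True` yields one row of logits per query. The sketch below is minimal and framework-independent. It assumes the logits have already been converted to nested Python lists, for example via `.numpy().tolist()` on the TensorFlow output tensor. `softmax`, `top_intent`, and `LABELS` are hypothetical helper names.

```python
import math
from typing import List, Tuple

# Short labels matching class_semantic_mapping in the usage example.
LABELS = ["Keyword", "Semantic", "QA"]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_intent(batch_logits: List[List[float]]) -> List[Tuple[str, float]]:
    """Return (label, probability) of the most likely class for each row."""
    results = []
    for row in batch_logits:
        probs = softmax(row)
        best = max(range(len(probs)), key=probs.__getitem__)
        results.append((LABELS[best], probs[best]))
    return results


# Usage with made-up logits for two queries (illustrative values, not model output).
print(top_intent([[1.0, 2.0, 5.0], [3.0, 0.5, 0.1]]))
```

Subtracting the row maximum before exponentiating leaves the probabilities unchanged but avoids overflow for large logits.
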