harshith20 commited on
Commit
e064628
·
1 Parent(s): c8709eb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +14 -12
README.md CHANGED
@@ -2,9 +2,12 @@
2
  license: openrail
3
  ---
4
  ```
 
5
  import torch
6
  from transformers import AutoTokenizer, MobileBertForSequenceClassification
7
 
 
 
8
  # Load the saved model
9
  model_name = 'harshith20/Emotion_predictor'
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -12,21 +15,20 @@ model = MobileBertForSequenceClassification.from_pretrained(model_name)
12
 
13
  # Tokenize input text
14
  input_text = "I am feeling happy today"
15
- encoded_text = tokenizer.encode_plus(
16
- input_text,
17
- max_length=128,
18
- padding='max_length',
19
- truncation=True,
20
- return_attention_mask=True,
21
- return_tensors='pt'
22
- )
23
 
24
  # Predict emotion
25
  with torch.no_grad():
26
- logits = model(**encoded_text)[0]
27
- predicted_emotion = torch.argmax(logits).item()
28
- emotion_labels = ['anger', 'fear', 'joy', 'love', 'sadness', 'surprise']
29
- predicted_emotion_label = emotion_labels[predicted_emotion]
 
 
 
 
30
 
31
  print(f"Input text: {input_text}")
32
  print(f"Predicted emotion: {predicted_emotion_label}")```
 
2
  license: openrail
3
  ---
4
  ```
5
+
6
  import torch
7
  from transformers import AutoTokenizer, MobileBertForSequenceClassification
8
 
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
  # Load the saved model
12
  model_name = 'harshith20/Emotion_predictor'
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
15
 
16
  # Tokenize input text
17
  input_text = "I am feeling happy today"
18
+ input_ids = tokenizer.encode(input_text, add_special_tokens=True, truncation=True, max_length=128)
19
+ input_tensor = torch.tensor([input_ids]).to(device)
20
+
 
 
 
 
 
21
 
22
  # Predict emotion
23
  with torch.no_grad():
24
+ outputs = model(input_tensor)
25
+ logits = outputs[0]
26
+
27
+ # Get the predicted label
28
+
29
+ predicted_emotion = torch.argmax(logits, dim=1).item()
30
+ emotion_labels = {0:'sadness',1:'joy',2:'love',3:'anger',4:'fear',5:'surprise'}
31
+ predicted_emotion_label = emotion_labels[predicted_emotion]
32
 
33
  print(f"Input text: {input_text}")
34
  print(f"Predicted emotion: {predicted_emotion_label}")```