logasanjeev committed on
Commit
513a891
·
verified ·
1 Parent(s): 40d7069

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +93 -12
inference.py CHANGED
@@ -1,18 +1,99 @@
1
- from transformers import BertForSequenceClassification, BertTokenizer
2
  import torch
3
  import json
4
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def predict(text):
 
7
  repo_id = "logasanjeev/goemotions-bert"
8
- model = BertForSequenceClassification.from_pretrained(repo_id)
9
- tokenizer = BertTokenizer.from_pretrained(repo_id)
10
- thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
11
- thresholds_data = json.loads(requests.get(thresholds_url).text)
12
- emotion_labels = thresholds_data["emotion_labels"]
13
- thresholds = thresholds_data["thresholds"]
14
- encodings = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  with torch.no_grad():
16
- logits = torch.sigmoid(model(**encodings).logits).numpy()[0]
17
- predictions = [{"label": emotion_labels[i], "score": float(logit)} for i, (logit, thresh) in enumerate(zip(logits, thresholds)) if logit >= thresh]
18
- return sorted(predictions, key=lambda x: x["score"], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import json
3
  import requests
4
+ import re
5
+ import emoji
6
+ from transformers import BertForSequenceClassification, BertTokenizer
7
+
8
def preprocess_text(text):
    """Normalize raw input text before it is handed to the tokenizer.

    The steps run in a fixed order, each one seeing the previous step's
    output:

    1. ``u/<name>`` mentions become ``[USER]``.
    2. ``r/<name>`` references become ``[SUBREDDIT]``.
    3. ``http://`` / ``https://`` links become ``[URL]``.
    4. Emoji are replaced by their textual names, delimited by spaces.
    5. The whole string is lowercased (this includes the placeholders
       inserted above).

    The pipeline is intended to mirror the preprocessing used at training
    time, so the patterns and their order should not be changed casually.

    Args:
        text (str): Raw input text.

    Returns:
        str: The normalized text.
    """
    # (pattern, placeholder) pairs, applied strictly in this order.
    substitutions = (
        (r'u/\w+', '[USER]'),
        (r'r/\w+', '[SUBREDDIT]'),
        (r'http[s]?://\S+', '[URL]'),
    )
    cleaned = text
    for pattern, placeholder in substitutions:
        cleaned = re.sub(pattern, placeholder, cleaned)

    # Emoji are spelled out before lowercasing so their names are lowercased too.
    return emoji.demojize(cleaned, delimiters=(" ", " ")).lower()
16
 
17
def load_model_and_resources():
    """Load the model, tokenizer, emotion labels, and thresholds from Hugging Face.

    The model and tokenizer are fetched with ``from_pretrained`` from the
    ``logasanjeev/goemotions-bert`` repository; the label names and the
    per-label decision thresholds are downloaded from the repository's
    ``optimized_thresholds.json`` file. The model is moved to CUDA when
    available (CPU otherwise) and switched to evaluation mode.

    Returns:
        tuple: ``(model, tokenizer, emotion_labels, thresholds, device)``.

    Raises:
        RuntimeError: If the model/tokenizer cannot be loaded, or if the
            thresholds file cannot be downloaded, parsed, or does not have
            the expected ``emotion_labels`` / ``thresholds`` keys. The
            original exception is attached as ``__cause__``.
    """
    repo_id = "logasanjeev/goemotions-bert"

    try:
        model = BertForSequenceClassification.from_pretrained(repo_id)
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        # Chain the cause so the underlying traceback is not lost.
        raise RuntimeError(f"Error loading model/tokenizer: {str(e)}") from e

    thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/optimized_thresholds.json"
    try:
        # A timeout keeps a stalled connection from hanging the caller forever.
        response = requests.get(thresholds_url, timeout=30)
        # Fail on 4xx/5xx here rather than trying to JSON-decode an error page.
        response.raise_for_status()
        thresholds_data = json.loads(response.text)
        if not (isinstance(thresholds_data, dict) and "emotion_labels" in thresholds_data and "thresholds" in thresholds_data):
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {str(e)}") from e

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    return model, tokenizer, emotion_labels, thresholds, device
42
+
43
# Module-level cache, filled on the first call to predict_emotions().
MODEL = TOKENIZER = EMOTION_LABELS = THRESHOLDS = DEVICE = None


def predict_emotions(text):
    """Predict emotions for the given text using the GoEmotions BERT model.

    On the first call the model, tokenizer, labels, thresholds and device are
    loaded via ``load_model_and_resources()`` and cached in module globals;
    later calls reuse them. The text is preprocessed, tokenized (padded and
    truncated to 128 tokens), run through the model without gradient
    tracking, and the sigmoid of each logit is compared against the matching
    per-label threshold.

    Args:
        text (str): The input text to analyze.

    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): One ``"<emotion>: <score>"`` line per label
              whose score reached its threshold, highest score first, or
              ``"No emotions predicted."`` when no label qualified.
            - processed_text (str): The preprocessed input text.
    """
    global MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE

    # Pay the loading cost only once per process.
    if MODEL is None:
        (MODEL, TOKENIZER, EMOTION_LABELS,
         THRESHOLDS, DEVICE) = load_model_and_resources()

    processed_text = preprocess_text(text)

    tokenizer_options = {
        'padding': 'max_length',
        'truncation': True,
        'max_length': 128,
        'return_tensors': 'pt',
    }
    batch = TOKENIZER(processed_text, **tokenizer_options)

    # Inputs must live on the same device as the model.
    ids_on_device = batch['input_ids'].to(DEVICE)
    mask_on_device = batch['attention_mask'].to(DEVICE)

    with torch.no_grad():
        model_out = MODEL(ids_on_device, attention_mask=mask_on_device)
        # Single input, so take the first (only) row of the batch.
        probabilities = torch.sigmoid(model_out.logits).cpu().numpy()[0]

    # Keep only the labels whose score reaches that label's own threshold.
    selected = [
        (EMOTION_LABELS[idx], round(prob, 4))
        for idx, (prob, cutoff) in enumerate(zip(probabilities, THRESHOLDS))
        if prob >= cutoff
    ]
    ranked = sorted(selected, key=lambda pair: pair[1], reverse=True)

    if not ranked:
        return "No emotions predicted.", processed_text

    report_lines = [f"{emotion}: {confidence:.4f}" for emotion, confidence in ranked]
    return "\n".join(report_lines), processed_text
87
+
88
if __name__ == "__main__":
    import argparse

    # Command-line entry point: takes one positional text argument, runs the
    # prediction, and prints the raw input, the preprocessed text, and the
    # formatted predictions.
    cli = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT model.")
    cli.add_argument("text", type=str, help="The input text to analyze for emotions.")
    raw_text = cli.parse_args().text

    formatted, cleaned = predict_emotions(raw_text)
    print(
        f"Input: {raw_text}",
        f"Processed: {cleaned}",
        "Predicted Emotions:",
        formatted,
        sep="\n",
    )