SCANSKY committed on
Commit
9179af7
·
verified ·
1 Parent(s): 584f0f6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +85 -138
handler.py CHANGED
@@ -1,149 +1,96 @@
1
- from transformers import pipeline
2
- from sklearn.preprocessing import LabelEncoder
3
- import joblib
4
- import torch
5
- import os
6
  from bertopic import BERTopic
7
- from sentence_transformers import SentenceTransformer
8
 
9
- # Debugging: Print current directory and contents
10
- print("Current working directory:", os.getcwd())
11
- print("Contents of the directory:", os.listdir())
12
-
13
- # Load the label encoder
14
- label_encoder = joblib.load('/repository/label_encoder.pkl') # Use absolute path
15
- print("Label encoder loaded successfully.")
16
-
17
- # Load the sentiment analysis model and tokenizer from Hugging Face
18
- model_name = "SCANSKY/BERTopic_Tourism_8L"
19
- sentiment_analyzer = pipeline(
20
- 'sentiment-analysis',
21
- model=model_name,
22
- tokenizer=model_name,
23
- device=0 if torch.cuda.is_available() else -1 # Use GPU if available
24
- )
25
 
26
- # Load BERTopic model
27
- embedding_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
28
- topic_model = BERTopic.load("/path/to/bertopic/model", embedding_model=embedding_model)
 
 
 
 
 
 
 
 
29
 
30
- def get_average_sentiment(positive_count, negative_count, neutral_count):
31
- total = positive_count + negative_count + neutral_count
32
- if total == 0:
33
- return "neutral"
 
 
 
 
34
 
35
- positive_pct = (positive_count / total) * 100
36
- negative_pct = (negative_count / total) * 100
37
- neutral_pct = (neutral_count / total) * 100
38
 
39
- max_sentiment = max(positive_pct, negative_pct, neutral_pct)
 
40
 
41
- if max_sentiment == positive_pct:
42
- return "positive"
43
- elif max_sentiment == negative_pct:
44
- return "negative"
45
- else:
46
- return "neutral"
47
-
48
- class EndpointHandler:
49
- def __init__(self, model_dir=None):
50
- # Model and tokenizer are loaded globally, so no need to reinitialize here
51
- # The `model_dir` argument is required by Hugging Face's inference toolkit
52
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- def preprocess(self, data):
55
- # Extract the input text from the request
56
- text = data.get("inputs", "")
57
- return text
 
 
 
58
 
59
- def inference(self, text):
60
- if not text.strip():
61
- return {"error": "Please enter some text for sentiment analysis."}
62
-
63
- # Split text into lines
64
- lines = [line.strip() for line in text.split('\n') if line.strip()]
65
-
66
- if not lines:
67
- return {"error": "Please enter valid text for sentiment analysis."}
68
-
69
- # Analyze each line for sentiment
70
- total_confidence = 0
71
- positive_count = 0
72
- negative_count = 0
73
- neutral_count = 0
74
- line_results = [] # Store results for each line
75
-
76
- for line in lines:
77
- result = sentiment_analyzer(line)
78
- predicted_label_encoded = int(result[0]['label'].split('_')[-1])
79
- predicted_label = label_encoder.inverse_transform([predicted_label_encoded])[0]
80
- confidence = result[0]['score'] * 100
81
-
82
- # Store line and its sentiment result
83
- line_results.append({
84
- 'text': line,
85
- 'sentiment': predicted_label,
86
- 'confidence': confidence
87
- })
88
-
89
- if predicted_label == 'positive':
90
- positive_count += 1
91
- elif predicted_label == 'negative':
92
- negative_count += 1
93
- else:
94
- neutral_count += 1
95
-
96
- total_confidence += confidence
97
-
98
- # Calculate averages
99
- avg_confidence = total_confidence / len(lines)
100
- positive_pct = (positive_count / len(lines)) * 100
101
- negative_pct = (negative_count / len(lines)) * 100
102
- neutral_pct = (neutral_count / len(lines)) * 100
103
-
104
- # Get average sentiment
105
- avg_sentiment = get_average_sentiment(positive_count, negative_count, neutral_count)
106
-
107
- # Perform topic inference using BERTopic's approximate_distribution
108
- merged_docs = "\n".join(lines)
109
- appxtopics, appxprobabilities = topic_model.approximate_distribution(
110
- merged_docs, window=16, batch_size=16 # Adjust window size for better alignment
111
- )
112
- doc_topic_distribution = appxtopics[0]
113
-
114
- # Rank topics by their contribution in descending order
115
- ranked_topics = sorted(enumerate(doc_topic_distribution), key=lambda x: x[1], reverse=True)[:10]
116
-
117
- # Prepare the output
118
- output = {
119
- "total_lines_analyzed": len(lines),
120
- "average_confidence": avg_confidence,
121
- "average_sentiment": avg_sentiment,
122
- "sentiment_distribution": {
123
- "positive": positive_pct,
124
- "negative": negative_pct,
125
- "neutral": neutral_pct
126
- },
127
- "line_results": line_results,
128
- "topic_distribution": {
129
- "ranked_topics": [
130
- {"topic_idx": topic_idx, "contribution": contribution}
131
- for topic_idx, contribution in ranked_topics
132
- ]
133
- }
134
- }
135
-
136
- return output
137
 
138
- def postprocess(self, output):
139
- if "error" in output:
140
- return [{"error": output["error"]}]
141
-
142
- # Return only the line-level results as a list
143
- return output["line_results"]
144
 
145
- def __call__(self, data):
146
- # Main method to handle the request
147
- text = self.preprocess(data)
148
- output = self.inference(text)
149
- return self.postprocess(output)
 
1
+ import json
 
 
 
 
2
  from bertopic import BERTopic
 
3
 
4
class EndpointHandler:
    """Request handler that runs BERTopic topic inference on submitted text.

    A request is a mapping whose ``"inputs"`` entry holds the text to analyse,
    one sentence per line (a list of strings is also accepted). All non-blank
    lines are merged into one document; that document is assigned a topic with
    ``transform`` and its per-topic contributions are estimated with
    ``approximate_distribution``.
    """

    # Sliding-window settings forwarded to ``approximate_distribution``.
    _WINDOW = 16
    _BATCH_SIZE = 16

    def __init__(self, model_path="SCANSKY/BERTopic_Tourism_8L"):
        """Load the BERTopic model.

        Args:
            model_path: Location handed to ``BERTopic.load``; defaults to the
                ``SCANSKY/BERTopic_Tourism_8L`` model identifier.
        """
        self.topic_model = BERTopic.load(model_path)

    def preprocess(self, data):
        """Extract the text to analyse from the request payload.

        Args:
            data: Mapping that carries the text under the ``"inputs"`` key.
                A missing key is treated as empty text.

        Returns:
            The input text as one string; list/tuple inputs are joined with
            newlines so that each item is treated as its own line.

        Raises:
            ValueError: If ``data`` is not a mapping or ``"inputs"`` is
                neither a string nor a list/tuple.
        """
        try:
            raw_input = data.get("inputs", "")
        except AttributeError as exc:
            raise ValueError(f"Error during preprocessing: {exc}") from exc

        if isinstance(raw_input, (list, tuple)):
            raw_input = "\n".join(str(item) for item in raw_input)
        if not isinstance(raw_input, str):
            raise ValueError(
                "Error during preprocessing: 'inputs' must be a string "
                "or a list of strings"
            )
        return raw_input

    def inference(self, text_input):
        """Assign topics to the text and describe each assignment.

        Args:
            text_input: Text with one sentence per line. Blank lines are
                ignored and the rest are joined into a single document.

        Returns:
            A list with one dict per analysed document (a single combined
            document here) holding ``topic``, ``probability``, ``top_words``
            (at most five), ``customLabel`` and ``contribution``.

        Raises:
            ValueError: If the input is not a string, contains no text, or
                the model fails while processing it.
        """
        if not isinstance(text_input, str):
            raise ValueError("Error during inference: input text must be a string")

        # splitlines() also copes with '\r\n'; dropping blank lines avoids
        # padding the combined document with empty fragments.
        sentences = [line.strip() for line in text_input.splitlines() if line.strip()]
        if not sentences:
            raise ValueError("Error during inference: no text provided in 'inputs'")
        combined_document = " ".join(sentences)

        try:
            topics, probabilities = self.topic_model.transform([combined_document])
            if probabilities is None:
                # The model may not produce probabilities at all; keep the
                # topic assignments and report the probability as unknown.
                probabilities = [None] * len(topics)

            # The second return value (token-level distributions) is not used.
            distributions, _ = self.topic_model.approximate_distribution(
                [combined_document],
                window=self._WINDOW,
                batch_size=self._BATCH_SIZE,
            )
            doc_topic_distribution = distributions[0]

            results = []
            for raw_topic, prob in zip(topics, probabilities):
                topic = int(raw_topic)
                topic_info = self.topic_model.get_topic(topic)
                topic_words = [word for word, _ in topic_info] if topic_info else []
                results.append({
                    "topic": topic,
                    "probability": self._scalar_probability(prob, topic),
                    "top_words": topic_words[:5],
                    "customLabel": self._custom_label(topic),
                    "contribution": self._contribution(doc_topic_distribution, topic),
                })
            return results
        except Exception as exc:
            raise ValueError(f"Error during inference: {exc}") from exc

    def _custom_label(self, topic):
        """Return the model's custom label for ``topic`` or ``"Topic N"``."""
        labels = getattr(self.topic_model, "custom_labels_", None)
        if labels is not None and len(labels) > 0:
            # Labels are stored in sorted-topic order, so the list is shifted
            # by one when an outlier topic (-1) exists. Default to 1 when the
            # model does not expose the offset (matches the previous lookup).
            offset = getattr(self.topic_model, "_outliers", 1)
            index = topic + offset
            if 0 <= index < len(labels):
                return labels[index]
        return f"Topic {topic}"

    @staticmethod
    def _scalar_probability(prob, topic):
        """Reduce a probability entry to a float (``None`` when unavailable)."""
        if prob is None:
            return None
        try:
            return float(prob)
        except (TypeError, ValueError):
            pass
        # A per-topic vector: pick the entry of the assigned topic.
        # NOTE(review): assumes column i corresponds to topic i - confirm.
        if 0 <= topic < len(prob):
            return float(prob[topic])
        return 0.0

    @staticmethod
    def _contribution(distribution, topic):
        """Return the approximate-distribution weight of ``topic``.

        Topics without a column (such as the outlier topic ``-1``) get 0.0;
        the bounds check also stops ``-1`` from wrapping to the last column.
        """
        if 0 <= topic < len(distribution):
            return float(distribution[topic])
        return 0.0

    def postprocess(self, results):
        """Return the inference results unchanged (already a list of dicts)."""
        return results

    def __call__(self, data):
        """Handle one request: preprocess, run inference, postprocess.

        Args:
            data: Request payload, see ``preprocess``.

        Returns:
            The list produced by ``inference``; on any failure a one-element
            list ``[{"error": <message>}]`` instead of raising.
        """
        try:
            text_input = self.preprocess(data)
            results = self.inference(text_input)
            return self.postprocess(results)
        except Exception as exc:
            # Endpoint boundary: report the failure in the response body.
            return [{"error": str(exc)}]