File size: 4,008 Bytes
9179af7 e4797c9 d1f3827 9179af7 d1f3827 9179af7 d1f3827 9179af7 d1f3827 9179af7 d1f3827 9179af7 d1f3827 9179af7 a963050 e4797c9 a963050 9179af7 a963050 9179af7 a963050 9179af7 d1f3827 9179af7 d1f3827 9179af7 d1f3827 9179af7 d1f3827 9179af7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import json
from bertopic import BERTopic
class EndpointHandler:
    """Inference-endpoint handler for the SCANSKY tourism BERTopic model.

    Accepts a request payload of the form ``{"inputs": "<newline-separated
    sentences>"}`` and returns a list of JSON-serializable topic-assignment
    dictionaries. Errors are returned as ``[{"error": "..."}]`` rather than
    raised, so the endpoint always responds with a list.
    """

    def __init__(self, model_path="SCANSKY/BERTopic_Tourism_8L"):
        """Load the BERTopic model.

        Args:
            model_path: Hugging Face Hub repo id (or local path) of the
                saved BERTopic model.
        """
        self.topic_model = BERTopic.load(model_path)

    def preprocess(self, data):
        """Extract the raw text input from the incoming request payload.

        Args:
            data: Request dictionary; text is read from the "inputs" key.

        Returns:
            The text input string ("" when the key is absent).

        Raises:
            ValueError: If the payload cannot be read (e.g. not a mapping).
        """
        try:
            return data.get("inputs", "")
        except Exception as e:
            raise ValueError(f"Error during preprocessing: {str(e)}")

    def _topic_label(self, topic_idx):
        """Return the custom label for ``topic_idx`` with a safe fallback.

        BERTopic stores ``custom_labels_`` offset by one because the outlier
        topic -1 occupies index 0 — hence the ``topic_idx + 1`` lookup.
        Falls back to ``"Topic <idx>"`` when no custom labels exist or the
        index is out of range (the original code crashed on a None
        ``custom_labels_`` in one of its two lookup sites).
        """
        labels = getattr(self.topic_model, "custom_labels_", None)
        if labels is not None and 0 <= topic_idx + 1 < len(labels):
            return labels[topic_idx + 1]
        return f"Topic {topic_idx}"  # Fallback label

    def inference(self, text_input):
        """Run topic inference on the combined input text.

        The input is treated as one sentence per line; all lines are joined
        into a single document so the model surfaces topics shared across
        the whole text.

        Args:
            text_input: Newline-separated sentences as a single string.

        Returns:
            A list of dictionaries with keys "topic", "probability",
            "top_words", "customLabel" and "contribution" (the latter is the
            ranked top-10 topic contribution table, repeated per result).

        Raises:
            ValueError: Wrapping any failure raised by the model.
        """
        try:
            # One sentence per line; join into a single document.
            sentences = text_input.strip().split('\n')
            combined_document = " ".join(sentences)

            # Primary topic assignment for the combined document.
            topics, probabilities = self.topic_model.transform([combined_document])

            # Detailed per-topic contributions via approximate distribution.
            # (Second return value, the token-level distribution, is unused.)
            topic_distr, _ = self.topic_model.approximate_distribution(
                combined_document, window=16, batch_size=16
            )
            doc_topic_distribution = topic_distr[0]

            # Rank topics by contribution, descending; keep the top 10.
            ranked_topics = sorted(
                enumerate(doc_topic_distribution), key=lambda x: x[1], reverse=True
            )[:10]
            distributedtopics = [
                [
                    f"Topic {topic_idx}",
                    self._topic_label(topic_idx),
                    f"{round(contribution * 100, 2)}%",
                ]
                for topic_idx, contribution in ranked_topics
            ]

            results = []
            for topic, prob in zip(topics, probabilities):
                topic_info = self.topic_model.get_topic(topic)
                topic_words = [word for word, _ in topic_info] if topic_info else []
                results.append({
                    "topic": int(topic),
                    "probability": float(prob),
                    "top_words": topic_words[:5],  # Top 5 words only
                    "customLabel": self._topic_label(topic),
                    "contribution": distributedtopics,
                })
            return results
        except Exception as e:
            raise ValueError(f"Error during inference: {str(e)}")

    def postprocess(self, results):
        """Return the inference results (already a JSON-serializable list)."""
        return results

    def __call__(self, data):
        """Handle one request end-to-end: preprocess, infer, postprocess.

        Never raises: any failure is returned as ``[{"error": "..."}]`` so
        the endpoint always yields a well-formed list response.
        """
        try:
            text_input = self.preprocess(data)
            results = self.inference(text_input)
            return self.postprocess(results)
        except Exception as e:
            return [{"error": str(e)}]  # Error surfaced as a one-item list