import json
import logging

from bertopic import BERTopic


class EndpointHandler:
    """Request handler that serves topic predictions from a BERTopic model.

    A request flows through ``preprocess`` -> ``inference`` -> ``postprocess``.
    ``__call__`` wires the three stages together and turns any failure into
    an error payload instead of raising.
    """

    # A topic id ``t`` is looked up at ``custom_labels_[t + 1]``.
    # NOTE(review): this assumes the label list starts with the outlier
    # topic (-1) -- confirm for models that were trained without outliers.
    _LABEL_OFFSET = 1

    def __init__(self, model_path="SCANSKY/BERTopic_Tourism_8L"):
        """Load the BERTopic model.

        Args:
            model_path: Passed unchanged to ``BERTopic.load``.
        """
        self.topic_model = BERTopic.load(model_path)

    def preprocess(self, data):
        """Extract and validate the text carried by a request.

        Args:
            data: Request payload; a mapping whose ``"inputs"`` entry holds
                the text to analyse.

        Returns:
            The text exactly as it was sent (no stripping is done here).

        Raises:
            ValueError: If ``data`` is not a mapping, or ``"inputs"`` is
                missing, not a string, or contains only whitespace.
        """
        try:
            text_input = data.get("inputs", "")
        except (AttributeError, TypeError) as e:
            raise ValueError(f"Error during preprocessing: {str(e)}") from e

        if not isinstance(text_input, str):
            raise ValueError(
                "Error during preprocessing: 'inputs' must be a string, "
                f"got {type(text_input).__name__}"
            )
        # An empty document carries no signal for the model, so report it to
        # the caller instead of returning topics for nothing.
        if not text_input.strip():
            raise ValueError(
                "Error during preprocessing: 'inputs' is missing or empty"
            )
        return text_input

    def _label_for(self, topic):
        """Return the custom label of ``topic``, or ``"Topic <id>"`` if none exists."""
        labels = getattr(self.topic_model, "custom_labels_", None)
        if labels is not None:
            position = int(topic) + self._LABEL_OFFSET
            if 0 <= position < len(labels):
                return labels[position]
        return f"Topic {topic}"

    def inference(self, text_input):
        """Predict the topic of the text and rank its strongest topic contributions.

        The newline-separated sentences are merged into a single document
        before the model is queried.

        Args:
            text_input: Text to analyse; newlines separate sentences.

        Returns:
            A list with one dict per predicted document, holding the keys
            ``topic``, ``probability``, ``top_words`` (at most five),
            ``customLabel`` and ``contribution``. ``contribution`` lists up to
            ten ``[topic name, label, percentage]`` entries, strongest first.

        Raises:
            ValueError: If any step of the prediction fails; the original
                exception is attached as the cause.
        """
        try:
            combined_document = " ".join(text_input.strip().split("\n"))

            topics, probabilities = self.topic_model.transform([combined_document])

            # Only the per-document topic distribution is needed; the second
            # return value is ignored.
            topic_distributions, _ = self.topic_model.approximate_distribution(
                combined_document, window=16, batch_size=16
            )
            ranked_topics = sorted(
                enumerate(topic_distributions[0]),
                key=lambda pair: pair[1],
                reverse=True,
            )[:10]

            distributed_topics = [
                [
                    f"Topic {topic_idx}",
                    self._label_for(topic_idx),
                    f"{round(contribution * 100, 2)}%",
                ]
                for topic_idx, contribution in ranked_topics
            ]

            results = []
            for topic, prob in zip(topics, probabilities):
                # ``get_topic`` may return a falsy value for an unknown topic.
                topic_info = self.topic_model.get_topic(topic)
                topic_words = [word for word, _ in topic_info] if topic_info else []

                results.append({
                    "topic": int(topic),
                    "probability": float(prob),
                    "top_words": topic_words[:5],
                    "customLabel": self._label_for(topic),
                    "contribution": distributed_topics,
                })

            return results
        except Exception as e:
            raise ValueError(f"Error during inference: {str(e)}") from e

    def postprocess(self, results):
        """Return the inference results unchanged.

        Args:
            results: Output of ``inference``.

        Returns:
            The same ``results`` object.
        """
        return results

    def __call__(self, data):
        """Handle one request end to end.

        Args:
            data: Request payload, see ``preprocess``.

        Returns:
            The processed results on success, otherwise a single-element
            list ``[{"error": <message>}]``.
        """
        try:
            text_input = self.preprocess(data)
            results = self.inference(text_input)
            return self.postprocess(results)
        except Exception as e:
            # Endpoint boundary: the caller only receives the message, so keep
            # the full traceback in the server log.
            logging.getLogger(__name__).exception("Request handling failed")
            return [{"error": str(e)}]