File size: 4,008 Bytes
9179af7
e4797c9
d1f3827
 
9179af7
 
 
 
 
 
d1f3827
9179af7
 
 
 
 
 
 
 
 
 
 
d1f3827
9179af7
 
 
 
 
 
 
 
d1f3827
9179af7
 
d1f3827
9179af7
 
d1f3827
9179af7
 
 
 
 
 
 
 
a963050
 
 
e4797c9
a963050
 
9179af7
 
 
 
 
 
 
 
 
 
 
 
 
a963050
9179af7
 
 
 
 
 
a963050
9179af7
 
 
 
 
 
 
 
 
 
 
d1f3827
9179af7
 
 
 
 
 
 
d1f3827
9179af7
 
d1f3827
9179af7
 
d1f3827
9179af7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import json

from bertopic import BERTopic

class EndpointHandler:
    """Hugging Face inference-endpoint handler wrapping a pretrained BERTopic model.

    Implements the preprocess -> inference -> postprocess pipeline; the
    endpoint runtime invokes the instance via ``__call__`` with the request
    payload and expects a JSON-serializable response.
    """

    def __init__(self, model_path="SCANSKY/BERTopic_Tourism_8L"):
        """Load the BERTopic model from the Hugging Face Hub (or a local path)."""
        self.topic_model = BERTopic.load(model_path)

    def preprocess(self, data):
        """Extract the text input from the incoming request payload.

        Args:
            data: Request dictionary; the text is expected under ``"inputs"``.

        Returns:
            The input text, or ``""`` when the key is absent.

        Raises:
            ValueError: If the payload cannot be read (e.g. it is not a dict).
        """
        try:
            return data.get("inputs", "")
        except Exception as e:
            raise ValueError(f"Error during preprocessing: {str(e)}")

    def inference(self, text_input):
        """Run topic inference on the input text.

        All lines of the input are merged into a single document, which is
        assigned a topic via ``transform``; ``approximate_distribution`` then
        provides per-topic contribution percentages for the same document.

        Args:
            text_input: Raw text; one sentence per line is assumed.

        Returns:
            A list of result dicts, one per transformed document, each with
            ``topic``, ``probability``, ``top_words``, ``customLabel`` and
            ``contribution`` (the top-10 ranked topic contributions).

        Raises:
            ValueError: Wrapping any error raised during inference.
        """
        try:
            # Merge all lines into one document (newlines -> spaces) so the
            # model sees a single text and shared topics can be found.
            combined_document = " ".join(text_input.strip().split('\n'))

            # Primary topic assignment for the combined document.
            topics, probabilities = self.topic_model.transform([combined_document])

            # Detailed per-topic contributions for the same document.
            # approximate_distribution returns (topic_distributions, token_distributions).
            topic_distr, _ = self.topic_model.approximate_distribution(
                combined_document, window=16, batch_size=16
            )
            doc_topic_distribution = topic_distr[0]

            # Top 10 topics by contribution, descending.
            ranked_topics = sorted(
                enumerate(doc_topic_distribution), key=lambda x: x[1], reverse=True
            )[:10]

            # custom_labels_ may be None when no custom labels were set on the
            # model; guard once and fall back to generic "Topic N" labels.
            # NOTE(review): the +1 offset assumes index 0 holds the -1 outlier
            # topic's label — confirm against how the labels were generated.
            custom_labels = getattr(self.topic_model, "custom_labels_", None)

            def _label(topic_idx):
                # Resolve a human-readable label for a topic index.
                if custom_labels is not None:
                    return custom_labels[topic_idx + 1]
                return f"Topic {topic_idx}"

            distributedtopics = [
                [f"Topic {topic_idx}", _label(topic_idx), f"{round(contribution * 100, 2)}%"]
                for topic_idx, contribution in ranked_topics
            ]

            # Assemble one result entry per transformed document.
            results = []
            for topic, prob in zip(topics, probabilities):
                topic_info = self.topic_model.get_topic(topic)
                topic_words = [word for word, _ in topic_info] if topic_info else []

                results.append({
                    "topic": int(topic),
                    "probability": float(prob),
                    "top_words": topic_words[:5],  # Top 5 words
                    "customLabel": _label(topic),
                    "contribution": distributedtopics
                })

            return results
        except Exception as e:
            raise ValueError(f"Error during inference: {str(e)}")

    def postprocess(self, results):
        """Return the inference results unchanged (already JSON-serializable)."""
        return results

    def __call__(self, data):
        """Handle an incoming request end-to-end.

        Args:
            data: Request payload dictionary.

        Returns:
            The list of inference results, or ``[{"error": ...}]`` on failure
            (errors are reported in-band rather than raised).
        """
        try:
            text_input = self.preprocess(data)
            results = self.inference(text_input)
            return self.postprocess(results)
        except Exception as e:
            return [{"error": str(e)}]