Update handler.py
Browse files- handler.py +11 -14
handler.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from transformers import pipeline
|
|
|
|
| 2 |
import joblib
|
| 3 |
import torch
|
| 4 |
import os
|
|
@@ -8,7 +9,7 @@ print("Current working directory:", os.getcwd())
|
|
| 8 |
print("Contents of the directory:", os.listdir())
|
| 9 |
|
| 10 |
# Load the label encoder
|
| 11 |
-
label_encoder = joblib.load('label_encoder.pkl') #
|
| 12 |
print("Label encoder loaded successfully.")
|
| 13 |
|
| 14 |
# Load the model and tokenizer from Hugging Face
|
|
@@ -39,16 +40,14 @@ def get_average_sentiment(positive_count, negative_count, neutral_count):
|
|
| 39 |
return "neutral"
|
| 40 |
|
| 41 |
class EndpointHandler:
|
| 42 |
-
def __init__(self):
|
| 43 |
-
#
|
|
|
|
| 44 |
pass
|
| 45 |
|
| 46 |
def preprocess(self, data):
|
| 47 |
# Extract the input text from the request
|
| 48 |
-
|
| 49 |
-
text = data.get("inputs", "")
|
| 50 |
-
else:
|
| 51 |
-
text = data # Fallback if data is not a dictionary
|
| 52 |
return text
|
| 53 |
|
| 54 |
def inference(self, text):
|
|
@@ -116,16 +115,14 @@ class EndpointHandler:
|
|
| 116 |
|
| 117 |
def postprocess(self, output):
|
| 118 |
if "error" in output:
|
| 119 |
-
return {"error": output["error"]}
|
| 120 |
|
| 121 |
-
# Return the
|
| 122 |
-
return output
|
|
|
|
| 123 |
|
| 124 |
def __call__(self, data):
|
| 125 |
# Main method to handle the request
|
| 126 |
text = self.preprocess(data)
|
| 127 |
output = self.inference(text)
|
| 128 |
-
return self.postprocess(output)
|
| 129 |
-
|
| 130 |
-
# Create an instance of the handler
|
| 131 |
-
handler = EndpointHandler()
|
|
|
|
| 1 |
from transformers import pipeline
|
| 2 |
+
from sklearn.preprocessing import LabelEncoder
|
| 3 |
import joblib
|
| 4 |
import torch
|
| 5 |
import os
|
|
|
|
| 9 |
print("Contents of the directory:", os.listdir())
|
| 10 |
|
| 11 |
# Load the label encoder
|
| 12 |
+
label_encoder = joblib.load('/repository/label_encoder.pkl') # Use absolute path
|
| 13 |
print("Label encoder loaded successfully.")
|
| 14 |
|
| 15 |
# Load the model and tokenizer from Hugging Face
|
|
|
|
| 40 |
return "neutral"
|
| 41 |
|
| 42 |
class EndpointHandler:
|
| 43 |
+
def __init__(self, model_dir=None):
|
| 44 |
+
# Model and tokenizer are loaded globally, so no need to reinitialize here
|
| 45 |
+
# The `model_dir` argument is required by Hugging Face's inference toolkit
|
| 46 |
pass
|
| 47 |
|
| 48 |
def preprocess(self, data):
|
| 49 |
# Extract the input text from the request
|
| 50 |
+
text = data.get("inputs", "")
|
|
|
|
|
|
|
|
|
|
| 51 |
return text
|
| 52 |
|
| 53 |
def inference(self, text):
|
|
|
|
| 115 |
|
| 116 |
def postprocess(self, output):
|
| 117 |
if "error" in output:
|
| 118 |
+
return [{"error": output["error"]}]
|
| 119 |
|
| 120 |
+
# Return only the line-level results as a list
|
| 121 |
+
return output["line_results"]
|
| 122 |
+
|
| 123 |
|
| 124 |
def __call__(self, data):
|
| 125 |
# Main method to handle the request
|
| 126 |
text = self.preprocess(data)
|
| 127 |
output = self.inference(text)
|
| 128 |
+
return self.postprocess(output)
|
|
|
|
|
|
|
|
|