# bert-product-cat / handler.py

import os
import json
import torch
import pickle
from transformers import BertForSequenceClassification, BertTokenizer


class CustomPipeline:
    def __init__(self, model_path):
        # Load the fine-tuned model and tokenizer from the provided path
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model.eval()  # inference only; keep dropout disabled
        # Load the fitted label encoder stored next to the model weights
        pkl_path = os.path.join(model_path, 'label_encoder.pkl')
        with open(pkl_path, 'rb') as le_file:
            self.le = pickle.load(le_file)

    def __call__(self, inputs):
        # Tokenize the input text
        tokenized_inputs = self.tokenizer(inputs, return_tensors='pt', truncation=True, padding=True)
        # Run the model without tracking gradients
        with torch.no_grad():
            logits = self.model(**tokenized_inputs).logits
        # Convert logits to class probabilities
        probs = logits.softmax(dim=1)
        # Pick the most likely class and its confidence (assumes a single input string)
        predicted_class = torch.argmax(probs, dim=1).item()
        round_confidence = round(probs[0][predicted_class].item(), 2)
        # Map the class index back to its original label string
        label = self.le.inverse_transform([predicted_class])[0]
        return {"label": label, "confidence": round_confidence, "metadata": "Bert-2023-10"}

# Initialize the pipeline once at import time to avoid reloading the model on every invocation.
# MODEL_PATH should point to a directory containing the model weights and label_encoder.pkl
# (e.g. bundled in the deployment package or a container image); it can be overridden
# with a Lambda environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "parvashah/bert-product-cat")
pipeline = CustomPipeline(MODEL_PATH)


def lambda_handler(event, context):
    try:
        # Extract the input text from the event. With API Gateway proxy integration
        # the body arrives as a JSON string; with a direct invocation it may already
        # be a dict, so handle both shapes of {"text": "your input here"}.
        body = event["body"]
        if isinstance(body, str):
            body = json.loads(body)
        text_input = body["text"]
        # Use the pipeline to get the prediction
        result = pipeline(text_input)
        return {
            "statusCode": 200,
            "body": json.dumps(result),
            "headers": {
                "Content-Type": "application/json"
            }
        }
    except Exception as e:
        return {
            "statusCode": 500,
            "body": json.dumps({"error": str(e)}),
            "headers": {
                "Content-Type": "application/json"
            }
        }
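

# Local smoke test (not executed by Lambda). A minimal sketch that assumes the
# model weights and label_encoder.pkl are readable from MODEL_PATH on this
# machine; the sample product title below is a hypothetical input.
if __name__ == "__main__":
    sample_event = {"body": json.dumps({"text": "wireless bluetooth headphones"})}
    response = lambda_handler(sample_event, None)
    print(response["statusCode"], response["body"])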