File size: 2,261 Bytes
4ec625b 1c17a19 4ec625b f4cb57e 4ec625b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import json
import logging
import os
import pickle

import torch
from transformers import BertForSequenceClassification, BertTokenizer
class CustomPipeline:
    """Classify a single piece of text with a fine-tuned BERT model.

    Bundles a ``BertForSequenceClassification`` model, its ``BertTokenizer``
    and a pickled label encoder (an object exposing ``inverse_transform``,
    presumably a scikit-learn ``LabelEncoder`` -- confirm) so that one call
    maps raw text to a decoded label plus a confidence score.
    """

    def __init__(self, model_path):
        """Load the model, tokenizer and label encoder.

        Args:
            model_path: Local directory or Hugging Face Hub repo id holding the
                model weights, the tokenizer files and ``label_encoder.pkl``.
        """
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.le = self._load_label_encoder(model_path)

    @staticmethod
    def _load_label_encoder(model_path):
        """Locate ``label_encoder.pkl`` for *model_path* and unpickle it."""
        if os.path.isdir(model_path):
            pkl_path = os.path.join(model_path, "label_encoder.pkl")
        else:
            # A Hub repo id such as "user/repo" is not a local directory, so
            # os.path.join would point at a file that does not exist. Resolve
            # it through the same download cache from_pretrained uses instead.
            from transformers.utils import cached_file

            pkl_path = cached_file(model_path, "label_encoder.pkl")
        # SECURITY: unpickling can execute arbitrary code -- only load label
        # encoders from a repository you trust.
        with open(pkl_path, "rb") as le_file:
            return pickle.load(le_file)

    def __call__(self, inputs):
        """Classify one text.

        Args:
            inputs: The text to classify (a ``str`` or a one-element list).

        Returns:
            dict with ``label`` (decoded by the label encoder and converted to
            a native Python value so it is JSON-serialisable), ``confidence``
            (softmax probability of that label, rounded to 2 decimals) and a
            fixed ``metadata`` tag.

        Raises:
            ValueError: if *inputs* tokenizes to more than one sequence.
        """
        tokenized_inputs = self.tokenizer(
            inputs, return_tensors="pt", truncation=True, padding=True
        )
        with torch.no_grad():
            logits = self.model(**tokenized_inputs).logits

        # Only one prediction is returned, so a larger batch would otherwise
        # fail deep inside .item() with an unhelpful message.
        batch_size = logits.shape[0]
        if batch_size != 1:
            raise ValueError(
                f"CustomPipeline expects a single text, got a batch of {batch_size}"
            )

        probs = logits.softmax(dim=1)[0]
        predicted_class = int(torch.argmax(probs))
        confidence = round(probs[predicted_class].item(), 2)

        label = self.le.inverse_transform([predicted_class])[0]
        # inverse_transform yields NumPy scalars; np.int64 is not accepted by
        # json.dumps, so unwrap to the equivalent built-in type.
        if hasattr(label, "item"):
            label = label.item()
        return {"label": label, "confidence": confidence, "metadata": "Bert-2023-10"}
# Initialize the pipeline once at import time to avoid reloading the model on
# every invocation. The model location can be overridden with the MODEL_PATH
# environment variable (e.g. set in the Lambda configuration); otherwise the
# original Hub repo id is used.
MODEL_PATH = os.environ.get("MODEL_PATH", "parvashah/bert-product-cat")
pipeline = CustomPipeline(MODEL_PATH)
def _json_response(status_code, payload):
    """Build a Lambda proxy-style response whose body is *payload* as JSON."""
    return {
        "statusCode": status_code,
        "body": json.dumps(payload),
        "headers": {
            "Content-Type": "application/json"
        },
    }


def lambda_handler(event, context):
    """AWS Lambda entry point: classify the ``text`` field of the request body.

    ``event["body"]`` may be a dict (direct invocation, e.g.
    ``{"body": {"text": "..."}}``) or a JSON-encoded string/bytes, which is how
    API Gateway proxy integrations deliver it.

    Args:
        event: The Lambda event.
        context: The Lambda context object (unused).

    Returns:
        A response dict with ``statusCode``, a JSON ``body`` and headers:
        200 with the prediction, 400 when the request has no usable ``text``
        field, or 500 with ``{"error": ...}`` when inference fails.
    """
    try:
        body = event["body"]
        if isinstance(body, (str, bytes, bytearray)):
            body = json.loads(body)
        text_input = body["text"]
    except (KeyError, TypeError, ValueError) as e:
        # Missing/garbled input is a client error, not a server failure.
        # (json.JSONDecodeError is a ValueError subclass.)
        return _json_response(400, {"error": f"Invalid request: {e}"})

    try:
        result = pipeline(text_input)
        # Serialisation stays inside the try so a non-JSON result becomes a 500.
        return _json_response(200, result)
    except Exception as e:
        # Top-level boundary: record the traceback before reporting failure.
        logging.getLogger(__name__).exception("Prediction failed")
        # NOTE(review): echoing str(e) may leak internal details to clients;
        # consider a generic message.
        return _json_response(500, {"error": str(e)})
|