# Page-header residue (not code): jpronchick -- "updates handler" -- cbddb88 verified
from typing import Dict, List, Any
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftConfig, PeftModel
from huggingface_hub import login
import os
import numpy as np
import logging
class EndpointHandler():
    """Endpoint handler that scores a response to a prompt item.

    The handler builds a single text of the form
    ``<prefix><item><connector><lower-cased response>`` and passes it to a
    ``text-classification`` pipeline; the first prediction's ``score`` is
    returned together with the text that was scored.
    """

    # Fixed framing placed around the item and the response before scoring.
    _PREFIX = "A surprising, creative, unexpected, or interesting way to "
    _CONNECTOR = " is: "

    # Payload returned when the request lacks a usable item or response.
    _NOT_DETECTED = 'ITEM OR RESPONSE NOT DETECTED'

    def __init__(self, path=""):
        """Load the tokenizer and build the text-classification pipeline.

        Args:
            path: Not used by the current loading code; kept so existing
                callers that pass a checkpoint directory keep working.

        Raises:
            KeyError: If the ``hftoken`` environment variable is not set.
        """
        # Configure logging once at start-up instead of on every request.
        logging.basicConfig(level=logging.INFO)

        model_name = "CNCL-Penn-State/llama3_8b_dp_effectiveness"
        hftoken = os.environ["hftoken"]

        # NOTE(review): the token is only handed to the tokenizer; the
        # pipeline loads the model without it -- confirm that is sufficient.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hftoken)
        self.model = pipeline(
            "text-classification",
            model=model_name,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Score one item/response pair.

        Args:
            data: Request payload. ``data["inputs"]`` is the response text
                and ``data["item"]`` is the prompt item; both must be
                strings. The dict is not modified.

        Returns:
            A one-element list. On success the element holds the scored
            text under ``'model_input'`` and the first prediction's score
            under ``'score'``. When either field is missing or is not a
            string, ``'model_input'`` is ``'ITEM OR RESPONSE NOT DETECTED'``
            and ``'score'`` is ``'NA'``.
        """
        logger = logging.getLogger(__name__)

        # Read without a fallback to ``data`` itself and without popping, so
        # that a missing key is detectable and the caller's dict is untouched.
        response = data.get("inputs")
        item = data.get("item")

        # Validate *before* building the text: concatenating or lower-casing
        # a missing / non-string value would raise instead of reporting it.
        if not isinstance(response, str) or not isinstance(item, str):
            return [{'model_input': self._NOT_DETECTED, 'score': 'NA'}]

        model_input = self._PREFIX + item + self._CONNECTOR + response.lower()
        logger.info("%s", model_input)

        prediction = self.model(model_input)
        score = prediction[0]['score']
        logger.info("%s", score)
        return [{'model_input': model_input, 'score': score}]