# image_to_text/handler.py
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, LlamaTokenizer, pipeline

# The packages below are not used directly; importing them makes the endpoint
# fail fast at startup if a required dependency is missing from the image.
import accelerate  # noqa: F401
import einops  # noqa: F401
import google.protobuf  # noqa: F401  (pip package name: protobuf)
import PIL  # noqa: F401
import sentencepiece  # noqa: F401
import torchvision  # noqa: F401
import triton  # noqa: F401
import xformers  # noqa: F401
class EndpointHandler():
    def __init__(self, path=""):
        # Note: the checkpoint is pulled from the Hub rather than from `path`.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            'THUDM/cogvlm-chat-hf',
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            # cache_dir='/tmp'
        ).to(self.device).eval()
        # CogVLM ships without its own tokenizer; the Vicuna tokenizer is the
        # one recommended on the model card.
        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
        # create inference pipeline
        # self.pipeline = pipeline(model=model, tokenizer=tokenizer)
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (:obj:`dict`):
                The request payload. The tokenized model inputs are expected
                under the "inputs" key (falling back to the payload itself).
        Return:
            A :obj:`list` with a single dict like [{"generated_text": "..."}]
            containing:
            - "generated_text": The text the model generated for the given
              image/prompt inputs, with the prompt tokens stripped off.
        """
        inputs = data.pop("inputs", data)
        gen_kwargs = {"max_length": 2048, "do_sample": False}
        # pass inputs with all kwargs in data
        # prediction = self.pipeline(inputs)
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)
        # drop the prompt tokens so only the newly generated text is decoded
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # post process the prediction
        return [{"generated_text": prediction}]
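

# --- Local usage sketch (not part of the handler) ---------------------------
# The handler expects `data["inputs"]` to already hold tokenized tensors. The
# sketch below shows one way a caller could build them with CogVLM's
# `build_conversation_input_ids` helper (exposed via trust_remote_code and
# documented on the THUDM/cogvlm-chat-hf model card). The image path and the
# query text are illustrative assumptions, not part of the original handler.
if __name__ == "__main__":
    from PIL import Image

    handler = EndpointHandler()
    image = Image.open('example.jpg').convert('RGB')  # hypothetical input file
    built = handler.model.build_conversation_input_ids(
        handler.tokenizer,
        query='Describe this image.',
        history=[],
        images=[image],
    )
    # Batch the single example and move everything to the model's device,
    # following the input layout shown on the model card.
    inputs = {
        'input_ids': built['input_ids'].unsqueeze(0).to(handler.device),
        'token_type_ids': built['token_type_ids'].unsqueeze(0).to(handler.device),
        'attention_mask': built['attention_mask'].unsqueeze(0).to(handler.device),
        'images': [[built['images'][0].to(handler.device).to(torch.bfloat16)]],
    }
    print(handler({"inputs": inputs}))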