| | |
| | import torch |
| | from transformers import pipeline, AutoProcessor, Blip2ForConditionalGeneration |
| | import os |
| | """import base64 |
| | from io import BytesIO |
| | from PIL import Image""" |
| |
|
| | |
# Pipeline device index: GPU 0 when CUDA is available, otherwise CPU (-1).
if torch.cuda.is_available():
    device = 0
else:
    device = -1
| |
|
class EndpointHandler():
    """Inference endpoint handler wrapping a BLIP-2 image-captioning model.

    Loads the BLIP-2 OPT-2.7b processor and an 8-bit sharded checkpoint at
    construction time; ``__call__`` handles one request.

    NOTE(review): the original ``__init__`` bound the processor and model to
    *local* variables, so both were discarded (and the model freed) as soon
    as ``__init__`` returned, while the disabled ``__call__`` body reads
    ``self.blip2_proc`` / ``self.blip2``. Fixed by assigning to ``self``.
    """

    def __init__(self, path=""):
        """Load the processor and the 8-bit BLIP-2 model.

        :param path: repository root; sharded weights are expected under
            ``<path>/sharded`` — TODO confirm against the deployment layout.
        """
        # Tokenizer + image preprocessor for BLIP-2 OPT-2.7b.
        self.blip2_proc = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

        # 8-bit quantized, automatically device-mapped model from the
        # sharded checkpoint directory.
        self.blip2 = Blip2ForConditionalGeneration.from_pretrained(
            os.path.join(path, "sharded"), device_map="auto", load_in_8bit=True
        )

    def __call__(self, data):
        """Handle one inference request; currently returns a placeholder.

        The real captioning logic is kept below as a disabled string literal.
        NOTE(review): before re-enabling it, fix these issues:
        (1) ``self.translator`` is never created in ``__init__``;
        (2) ``if lang != None or lang == "de"`` is true for ANY non-None
            lang — presumably ``lang == "de"`` was intended;
        (3) ``if decode != None or decode == "nucleus"`` has the same
            or/and confusion (it also fires for ``decode == "beam"``);
        (4) ``do_sample=True`` is passed to ``batch_decode`` but belongs on
            ``generate`` for nucleus sampling.
        """
        """b64_img = data.pop("b64", data)
        lang = data.pop("lang", None)
        decode = data.pop("decode", None)

        #prepare image
        im_bytes = base64.b64decode(b64_img)   # im_bytes is a binary image
        im_file = BytesIO(im_bytes)  # convert image to file-like object
        image = Image.open(im_file).convert("RGB")
        output = {}
        inputs = self.blip2_proc(image, return_tensors="pt").to(device, torch.float16)
        #nucleus vs beam sampling
        if decode == None or decode == "beam":
            generated_ids = self.blip2.generate(**inputs, max_new_tokens=20)
            prediction = self.blip2_proc.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            #english vs german caption
            if lang != None or lang == "de":
                translation = self.translator(prediction)
                output["beam"] = translation[0]
            else:
                output["beam"] = prediction
        if decode != None or decode == "nucleus":
            generated_ids = self.blip2.generate(**inputs, max_new_tokens=20)
            prediction = self.blip2_proc.batch_decode(generated_ids, skip_special_tokens=True, do_sample=True)[0].strip()
            #english vs german caption
            if lang != None or lang == "de":
                translation = self.translator(prediction)
                output["nucleus"] = translation[0]
            else:
                output["nucleus"] = prediction

        # postprocess the prediction
        return output"""
        # Placeholder response until the captioning path above is restored.
        return 73
| |
|