| import torch |
| import sys |
| from subprocess import run |
| from PIL import Image |
| import os |
| import base64 |
|
|
| |
# Upgrade pip and install the CUDA 12.4 PyTorch wheels at import time.
# Running pip as ``sys.executable -m pip`` makes sure the packages go into the
# interpreter that is executing this file. A bare ``pip`` found on PATH may
# belong to a different environment. Passing an argument list instead of a
# shell string also avoids ``shell=True``.
# NOTE(review): ``torch`` is already imported at the top of this file, so any
# version installed here only takes effect after the process restarts.
run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"], check=True)
run(
    [
        sys.executable, "-m", "pip", "install",
        "torch", "torchvision", "torchaudio",
        "--extra-index-url", "https://download.pytorch.org/whl/cu124",
    ],
    check=True,
)
|
|
|
|
| from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig |
|
|
|
|
|
|
# Hugging Face Hub identifiers: the base vision-language model, and the adapter
# that EndpointHandler loads on top of it.
model_id = "ibm-granite/granite-vision-3.2-2b"
adapter_id = "Portx/granite-vision-3.2-2b-20250403-full"

# Integer device selector: 0 when a CUDA device is available, otherwise -1.
# NOTE(review): -1 is not a valid torch device for ``Tensor.to(...)``; callers
# should not pass it to ``.to()`` on CPU-only hosts.
if torch.cuda.is_available():
    device = 0
else:
    device = -1
|
|
class Utils:
    """Small file helpers used by the request handler."""

    @staticmethod
    def convert_base64_to_jpg(base64_string, path="./do_img.jpg"):
        """Decode ``base64_string`` and write the raw bytes to ``path``.

        The decoded bytes are written unchanged; nothing is re-encoded. The
        file therefore holds a JPEG only if the decoded payload is one.

        Args:
            base64_string: base64-encoded image data (``str`` or ``bytes``).
            path: destination file. Defaults to ``./do_img.jpg``, the path
                the request handler's chat template reads the image from.

        Returns:
            The path that was written.

        Raises:
            TypeError: if ``base64_string`` is ``None`` (e.g. the request had
                no image).
            binascii.Error: if the input is not valid base64 (for example,
                incorrect padding).
        """
        if base64_string is None:
            raise TypeError("base64_string is required (got None)")
        image_data = base64.b64decode(base64_string)
        with open(path, "wb") as f:
            f.write(image_data)
        return path
|
|
class PromptSet:
    """Namespace of fixed prompt strings sent to the vision model.

    The text of these prompts is runtime input to the model and is kept
    verbatim. NOTE(review): the adapter may have been fine-tuned on this exact
    wording, so confirm before editing grammar or formatting.
    """

    # System turn: tells the model to extract 10 keys and to use '-' for
    # missing values. The 10 keys are the ones listed in
    # main_order_information_prompt below.
    system_message = "You are an expert in analyzing and extracting information from freight, shipment, or delivery orders. Please carefully read the provided order file and extract the following 10 key pieces of information. Ensure that the key names are exactly as listed below. Do not create any additional key names other than these. If any information is missing or unavailable, output '-'."
    # Full-order extraction prompt listing the 10 expected JSON keys.
    # Contains literal "{" / "}" characters, so it must NOT be passed through
    # str.format(). NOTE(review): "bill_of_lading: .." has two dots where the
    # other keys use three.
    main_order_information_prompt = """Extract the order document.
#Output:
{container_number: ...,
bill_of_lading: ..,
importing_carrier: ...,
origin_address: ...,
destination_address: ...,
container_weight: ...,
container_weight_unit: ...,
container_type: ...,
po_number: ...,
reference_number: ...
}
Guidelines:
- Very important: do not make up anything. If the information of a required field is not available, output '-' for it.
- Output in JSON format. The JSON should contain the above 10 keys.
"""
    # Asks the model for the list of container numbers as a JSON array.
    order_list_prompt = "How much container are there? Give to me all container numbers only in a json array?"
    # Per-container detail prompt. It contains a "{query}" placeholder (twice)
    # that this class does not fill; the caller must substitute it, presumably
    # with a container number. Note that its example uses "weight" /
    # "weight_unit" rather than the "container_weight" /
    # "container_weight_unit" keys of the main prompt.
    multiple_container_information_prompt = "Give to me container weight, container weight unit,the container size (with type) of {query} in the same line with container_number:{query}.You must response only in a JSON format. Example output is must be 'container_number': 'OOCU6979480', 'container_type': '40HC or DV', 'weight': '46,737.52', 'weight_unit': 'LB'"
|
|
|
|
class EndpointHandler():
    """Inference handler for the Granite Vision model plus the fine-tuned adapter.

    ``__init__`` loads the model and processor once. Each ``__call__`` runs a
    single request (a base64 image plus a prompt) through ``generate``.
    """

    def __init__(self, path=""):
        """Load the base model, attach the adapter and load the processor.

        Args:
            path: accepted for the handler interface but not used. Weights are
                loaded from the module-level ``model_id`` and ``adapter_id``.
        """
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.model.load_adapter(adapter_id)

        self.processor = AutoProcessor.from_pretrained(
            model_id,
            use_fast=True,
            trust_remote_code=True,
        )

    @staticmethod
    def _select_prompt(prompt_id, inputs):
        """Return the user-turn text: a canned prompt for ids 1-3, else ``inputs``."""
        if prompt_id == 1:
            return PromptSet.main_order_information_prompt
        if prompt_id == 2:
            return PromptSet.order_list_prompt
        if prompt_id == 3:
            # This template has a "{query}" placeholder that was previously sent
            # to the model unfilled. Fill it from the request's ``inputs``.
            # replace() is used instead of format() so other braces can't break it.
            return PromptSet.multiple_container_information_prompt.replace(
                "{query}", str(inputs)
            )
        return inputs

    def __call__(self, data):
        """Run one extraction request.

        Args:
            data: request payload dict. These keys are popped from it:
                ``inputs``: free-text prompt, also the ``{query}`` value for
                    prompt 3. Defaults to ``data`` itself when absent.
                ``parameters``: accepted but currently unused.
                ``prompt_id``: 1, 2 or 3 selects a canned prompt from
                    ``PromptSet``; any other value uses ``inputs``.
                ``image``: base64-encoded image.

        Returns:
            ``output[0]`` from ``generate``, decoded with special tokens
            skipped. NOTE(review): for a decoder-only ``generate`` this
            presumably includes the prompt text as well as the answer; confirm.

        Raises:
            ValueError: if the payload has no ``image``.
        """
        inputs = data.pop("inputs", data)
        data.pop("parameters", None)  # accepted for compatibility, not used
        prompt_id = data.pop("prompt_id", None)
        base64_image = data.pop("image", None)
        if base64_image is None:
            raise ValueError("request payload must include a base64-encoded 'image'")

        # Writes the decoded image to ./do_img.jpg, which the chat template
        # below reads back. NOTE(review): the fixed path is shared by every
        # request, so concurrent calls could read each other's image.
        Utils.convert_base64_to_jpg(base64_image)

        final_prompt = self._select_prompt(prompt_id, inputs)

        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": PromptSet.system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": "./do_img.jpg"},
                    {"type": "text", "text": final_prompt},
                ],
            },
        ]

        # Move the inputs to wherever the model was placed. The module-level
        # ``device`` is -1 on CPU-only hosts, which ``.to()`` rejects.
        model_inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        output = self.model.generate(**model_inputs, max_new_tokens=512)
        return self.processor.decode(output[0], skip_special_tokens=True)