Portx committed on
Commit
2ea4855
·
verified ·
1 Parent(s): 7e2c4ad

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +117 -31
handler.py CHANGED
@@ -1,48 +1,134 @@
1
  import torch
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  # check for GPU
5
  device = 0 if torch.cuda.is_available() else -1
6
-
7
- # multi-model list
8
- multi_model_list = [
9
- {"model_id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"},
10
- {"model_id": "Helsinki-NLP/opus-mt-en-de", "task": "translation"},
11
- {"model_id": "facebook/bart-large-cnn", "task": "summarization"},
12
- {"model_id": "dslim/bert-base-NER", "task": "token-classification"},
13
- {"model_id": "textattack/bert-base-uncased-ag-news", "task": "text-classification"},
14
- {"model_id": "openai-community/gpt2","task": "text-generation"}
15
- ]
16
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  class EndpointHandler():
18
  def __init__(self, path=""):
19
- self.multi_model={}
20
- # load all the models onto device
21
- for model in multi_model_list:
22
- self.multi_model[model["model_id"]] = pipeline(model["task"], model=model["model_id"], device=device)
23
-
24
  def __call__(self, data):
25
  # deserialize incoming request
26
  inputs = data.pop("inputs", data)
27
  parameters = data.pop("parameters", None)
28
- model_id = data.pop("model_id", None)
29
  prompt_id = data.pop("prompt_id", None)
 
 
 
30
 
 
31
  if prompt_id==1:
32
- inputs="What is the ai in 3 sentences?"
33
  elif prompt_id==2:
34
- inputs="Who is elon musk?"
 
 
35
  else:
36
  pass
37
 
38
- # check if model_id is in the list of models
39
- if model_id is None or model_id not in self.multi_model:
40
- raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
41
-
42
- # pass inputs with all kwargs in data
43
- if parameters is not None:
44
- prediction = self.multi_model[model_id](inputs, **parameters)
45
- else:
46
- prediction = self.multi_model[model_id](inputs)
47
- # postprocess the prediction
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return prediction
 
import base64
import sys
from subprocess import run

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
# Install FlashAttention at startup.  Invoking pip as `sys.executable -m pip`
# (a list of args, no shell) guarantees the wheel lands in the interpreter
# that is actually running this handler and avoids shell parsing entirely;
# the original `run("pip install ...", shell=True)` could target a different
# Python on multi-env hosts.
run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    check=True,
)

# Vision-language model served by this endpoint.
model_id = "ibm-granite/granite-vision-3.1-2b-preview"

# 4-bit NF4 quantization config so the model fits on a small GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,   # second quantization pass for extra memory savings
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    # keep the vision tower and the output head un-quantized for quality
    llm_int8_skip_modules=["vision_tower", "lm_head"],
    llm_int8_enable_fp32_cpu_offload=True,  # allow spill-over layers on CPU in fp32
)

# FlashAttention is optional: fall back to the default attention implementation
# when the install above produced no importable wheel.
try:
    import flash_attn  # noqa: F401  -- imported only to probe availability
    print("FlashAttention is installed")
    USE_FLASH_ATTENTION = True
except ImportError:
    print("FlashAttention is not installed")
    USE_FLASH_ATTENTION = False
# check for GPU.  Use a torch device string rather than the `pipeline`-style
# 0/-1 integer: this module only ever uses `device` in `.to(device)` calls,
# and `tensor.to(-1)` raises on CPU-only hosts, whereas `.to("cpu")` /
# `.to("cuda:0")` are always valid.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): a stray module-level call
#   inputs = processor.apply_chat_template(conversation, ...).to(device)
# used to live here.  It referenced `processor` and `conversation` before
# either name existed, so importing this module raised NameError immediately.
# The same template call is performed (correctly, on the instance attributes)
# inside EndpointHandler.__call__, so the dead copy has been removed.
class Utils:
    """Small helper utilities for the endpoint handler."""

    @staticmethod
    def convert_base64_to_jpg(base64_string):
        """Decode a base64-encoded image and write it to ./do_img.jpg.

        Declared as a @staticmethod so it behaves the same whether called on
        the class or on an instance (the original bare function would have
        swallowed `self` as the payload on instance calls).

        Args:
            base64_string: base64 payload (str or bytes) of the image.

        Returns:
            The path of the written file, so callers need not hard-code it.
            (Previously returned None; callers ignoring the result are
            unaffected.)

        Raises:
            binascii.Error: if the payload is not valid base64.
            TypeError: if the payload is None.
        """
        image_data = base64.b64decode(base64_string)
        path = "./do_img.jpg"
        with open(path, "wb") as f:
            f.write(image_data)
        return path
class PromptSet:
    """Prompt templates for the freight/delivery-order extraction endpoint.

    `prompt_id` in the incoming request selects one of the three task prompts
    below; `system_message` is always sent as the system turn.
    """

    # Persona + global contract: exactly the listed keys, output '-' for
    # anything missing.
    system_message = "You are an expert in analyzing and extracting information from freight, shipment, or delivery orders. Please carefully read the provided order file and extract the following 10 key pieces of information. Ensure that the key names are exactly as listed below. Do not create any additional key names other than these. If any information is missing or unavailable, output '-'."

    # prompt_id == 1: full 10-field extraction of a single order, JSON output.
    main_order_information_prompt = """
#Key names and their descriptions:
1. container_number: The container number/no of the shipment (e.g., TRKU2038448, MSDU8549321). This should be an 11-character container number, with no additional format. If not available, output '-'.
2. bill_of_lading: The Bill of Lading number, which could include formats such as B/L No., AWS No., BL No., or ocean Bill of Lading (e.g., AXVJMER000008166, TRKU-10152009, HLCU ALY241000275). If not available, output '-'.
3. importing_carrier: The importing or ocean carrier, which may include SCAC codes, carrier's local agents, or sea line codes. If not available, output '-'.
4. origin_address: The address for picking up the container, such as the origin address, pickup location, terminal, or port of discharge. Exclude loading location information. (e.g., "PORT LIBERTY NY CONTAINER TERMINAL 300 WESTERN AVE"). If not available, output '-'.
5. destination_address: The address where the container is to be delivered, typically a company name or a specific delivery location (e.g., "AERO RECEIVING EAST, 2 BRICK PLANT ROAD, SOUTH RIVER, NJ"). If not available, output '-'.
6. container_weight: The weight of the container (in numeric format, e.g., 58,201.44). If there are multiple weights, output the highest value. If not available, output '-'.
7. container_weight_unit: The unit of measurement for the container's weight (e.g., LBS, KGS, KG, LB). If not available, output '-'.
8. container_type: The type/size of the container (e.g., 40HC, 20GP FCL). If not available, output '-'.
9. po_number: The purchase order number or customer’s PO (e.g., PO Number, customer’s PO, consol). If not available, output '-'.
10. reference_number: The reference number, file number, or any internal reference (e.g., reference number, our ref no.). If not available, output '-'.
#Output:
{container_number: ...,
bill_of_lading: ..,
importing_carrier: ...,
origin_address: ...,
destination_address: ...,
container_weight: ...,
container_weight_unit: ...,
container_type: ...,
po_number: ...,
reference_number: ...
}
Guidelines:
- Very important: do not make up anything. If the information of a required field is not available, output '-' for it.
- Output in JSON format. The JSON should contain the above 10 keys.
"""

    # prompt_id == 2: enumerate every container number as a JSON array.
    order_list_prompt = "How much container are there? Give to me all container numbers only in a json array?"

    # prompt_id == 3: per-container weight/size lookup.  NOTE(review): the
    # template contains {query} placeholders but no .format() call is visible
    # in this file -- presumably the caller substitutes the container number;
    # verify against the client.
    multiple_container_information_prompt = "Give to me container weight, container weight unit,the container size (with type) of {query} in the same line with container_number:{query}.You must response only in a JSON format. Example output is must be 'container_number': 'OOCU6979480', 'container_type': '40HC or DV', 'weight': '46,737.52', 'weight_unit': 'LB'"
class EndpointHandler():
    """Inference-endpoint handler around the Granite vision-language model.

    Loads the 4-bit quantized model once at startup and answers
    document-extraction requests built from `PromptSet` templates.
    """

    def __init__(self, path=""):
        # `path` is part of the hosting contract; the model is pulled from the
        # hub by `model_id` instead, so the argument is accepted but unused.
        self.model = AutoModelForVision2Seq.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
            _attn_implementation="flash_attention_2" if USE_FLASH_ATTENTION else None,
        )
        self.processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

    def __call__(self, data):
        """Handle one inference request.

        Expected payload keys:
            inputs     -- free-form input (currently unused; prompts are fixed)
            parameters -- optional dict of kwargs forwarded to `generate`
            prompt_id  -- 1, 2 or 3; selects the PromptSet template
            image      -- base64-encoded image of the order document

        Returns:
            The decoded model output as a string.

        Raises:
            ValueError: on a missing image or an unknown prompt_id.
        """
        # deserialize incoming request
        inputs = data.pop("inputs", data)  # kept for interface compatibility
        parameters = data.pop("parameters", None)
        prompt_id = data.pop("prompt_id", None)
        base64_image = data.pop("image", None)

        # Fail with a clear message instead of crashing inside b64decode(None).
        if base64_image is None:
            raise ValueError("Request is missing the required 'image' field")
        Utils.convert_base64_to_jpg(base64_image)

        # Map prompt_id -> template; raise on unknown ids instead of falling
        # through with `final_prompt` undefined (NameError in the original).
        prompt_map = {
            1: PromptSet.main_order_information_prompt,
            2: PromptSet.order_list_prompt,
            3: PromptSet.multiple_container_information_prompt,
        }
        try:
            final_prompt = prompt_map[prompt_id]
        except KeyError:
            raise ValueError(
                f"prompt_id: {prompt_id} is not valid. Available ids are: {list(prompt_map)}"
            ) from None

        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": PromptSet.system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "url": "./do_img.jpg"},
                    {"type": "text", "text": final_prompt},
                ],
            },
        ]

        # Tokenize and move the batch to wherever device_map="auto" placed the
        # model; using self.model.device (not the module-level `device`) keeps
        # this correct on CPU-only hosts.
        model_inputs = self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        # The original referenced bare `model` / `processor` here, which do
        # not exist in this scope (NameError) -- they must be the instance
        # attributes.  Optional user parameters may override max_new_tokens.
        gen_kwargs = {"max_new_tokens": 512}
        if parameters is not None:
            gen_kwargs.update(parameters)
        output = self.model.generate(**model_inputs, **gen_kwargs)
        prediction = self.processor.decode(output[0], skip_special_tokens=True)
        return prediction