BoghdadyJR committed (verified)
Commit 1471372 · Parent: ee1fa58

Add inference handler for HF Endpoints

Files changed (2):
  1. handler.py +95 -68
  2. requirements.txt +6 -5
handler.py CHANGED
@@ -1,20 +1,43 @@
 from typing import Dict, List, Any
-from unsloth import FastVisionModel
+from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
 from PIL import Image
 import torch
+import io
+import base64
+from peft import PeftModel
 
 class EndpointHandler():
-    def __init__(self, path=""):
-        # Load model and tokenizer
-        self.model, self.tokenizer = FastVisionModel.from_pretrained(
-            path,
+    def __init__(self):
+        self.path = "BoghdadyJR/Qwen_UI_final"
+        # Load base model and tokenizer
+        base_model_id = "Qwen/Qwen2-VL-2B-Instruct"
+
+        # Load tokenizer/processor
+        self.processor = AutoProcessor.from_pretrained(
+            self.path,
+            trust_remote_code=True
+        )
+
+        # Load base model
+        self.model = AutoModelForVision2Seq.from_pretrained(
+            base_model_id,
+            torch_dtype=torch.float16,
             device_map="auto",
-            load_in_4bit=False,  # Use 4bit to reduce memory use. False for 16bit LoRA
-            use_gradient_checkpointing="unsloth",  # True or "unsloth" for long context
+            trust_remote_code=True
         )
 
-        # Enable for inference
-        FastVisionModel.for_inference(self.model)
+        # Load LoRA adapter
+        self.model = PeftModel.from_pretrained(
+            self.model,
+            self.path,
+            device_map="auto"
+        )
+
+        # Merge and unload for faster inference
+        self.model = self.model.merge_and_unload()
+
+        # Set to eval mode
+        self.model.eval()
 
         # Store the instruction template
         self.instruction = """
@@ -46,77 +69,81 @@ Brief Structured Report:
 </answer>
 """
 
-    def __call__(self, data: Any) -> List[Dict[str, Any]]:
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
         """
-        Args:
-            data (:obj:):
-                includes the input data and the parameters for the inference.
-                Expected format:
-                {
-                    "inputs": {
-                        "image": PIL.Image object,
-                        "instruction": optional_custom_instruction
-                    },
-                    "parameters": {
-                        "max_new_tokens": 512,
-                        "temperature": 0.7,
-                        "top_p": 0.9,
-                        ...
-                    }
-                }
+        data args:
+            inputs (:obj: `str` | `PIL.Image` | `np.array`)
+            parameters (:obj: `Dict[str, Any]`, *optional*)
         Return:
-            A :obj:`list`. The list contains a dictionary with:
-                - "generated_text": The model's response
+            A :obj:`list` | `dict`: will be serialized and returned
         """
+        # Extract inputs and parameters
         inputs = data.pop("inputs", data)
         parameters = data.pop("parameters", {})
 
-        # Extract image and instruction
-        image = inputs.get("image")
-        custom_instruction = inputs.get("instruction", self.instruction)
+        # Handle different input formats
+        if isinstance(inputs, str):
+            # Base64 encoded image
+            image_bytes = base64.b64decode(inputs)
+            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+        elif isinstance(inputs, dict):
+            # Dictionary with image key
+            image_data = inputs.get("image", inputs.get("inputs", ""))
+            if isinstance(image_data, str):
+                image_bytes = base64.b64decode(image_data)
+                image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+            else:
+                image = image_data
+        else:
+            # Direct image
+            image = inputs
+
+        # Ensure image is RGB
+        if image.mode != "RGB":
+            image = image.convert("RGB")
 
-        # Prepare messages
+        # Prepare messages in Qwen format
         messages = [
-            {"role": "user", "content": [
-                {"type": "image"},
-                {"type": "text", "text": custom_instruction}
-            ]}
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image": image},
+                    {"type": "text", "text": self.instruction}
+                ]
+            }
         ]
 
-        # Apply chat template
-        input_text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+        # Process inputs
+        text = self.processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
 
-        # Tokenize inputs
-        model_inputs = self.tokenizer(
-            image,
-            input_text,
-            add_special_tokens=False,
-            return_tensors="pt",
+        # Prepare inputs for model
+        inputs = self.processor(
+            text=[text],
+            images=[image],
+            padding=True,
+            return_tensors="pt"
         ).to(self.model.device)
 
-        # Set default parameters
-        generation_params = {
-            "max_new_tokens": parameters.get("max_new_tokens", 512),
-            "temperature": parameters.get("temperature", 0.7),
-            "top_p": parameters.get("top_p", 0.9),
-            "min_p": parameters.get("min_p", 0.1),
-            "use_cache": True,
-            "do_sample": parameters.get("do_sample", True),
-            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
-        }
-
-        output_ids = self.model.generate(
-            **model_inputs,
-            **generation_params
-        )
-
-        # Decode output
-        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
+        # Generate response
+        with torch.no_grad():
+            output_ids = self.model.generate(
+                **inputs,
+                max_new_tokens=parameters.get("max_new_tokens", 512),
+                temperature=parameters.get("temperature", 0.7),
+                top_p=parameters.get("top_p", 0.9),
+                do_sample=True,
+                pad_token_id=self.processor.tokenizer.pad_token_id,
+                eos_token_id=self.processor.tokenizer.eos_token_id,
+            )
 
-        # Extract only the generated response (remove the prompt)
-        response = generated_text.split(custom_instruction)[-1].strip()
+        # Decode output - only the generated part
+        output_text = self.processor.batch_decode(
+            output_ids[:, inputs.input_ids.shape[1]:],
+            skip_special_tokens=True
+        )[0]
 
-        return [{
-            "generated_text": response
-        }]
+        return [{"generated_text": output_text}]
 
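For reference, a minimal client-side sketch of calling the handler once deployed. The endpoint URL and token are placeholders, and the payload mirrors the `__call__` contract above: a base64-encoded image string under "inputs", plus optional generation "parameters".

import base64
import requests

# Placeholders: substitute your Inference Endpoint URL and HF token.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

# Base64-encode a screenshot, matching the str branch of __call__.
with open("screenshot.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": image_b64,
    "parameters": {"max_new_tokens": 512, "temperature": 0.7, "top_p": 0.9},
}

resp = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
)
# The handler returns [{"generated_text": ...}].
print(resp.json()[0]["generated_text"])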
requirements.txt CHANGED
@@ -1,5 +1,6 @@
-torch>=2.0.0
-transformers>=4.36.0
-unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
-Pillow>=9.0.0
-accelerate>=0.25.0
+transformers==4.52.4
+accelerate==1.7.0
+peft==0.15.2
+pillow==11.2.1
+torch==2.7.1
+torchvision==0.22.1
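
With the pinned requirements installed, the handler can be smoke-tested locally before deploying. A minimal sketch, assuming a CUDA GPU for the float16 weights and a placeholder image path:

from PIL import Image
from handler import EndpointHandler

# Loads the base Qwen2-VL model plus the LoRA adapter (downloads on first run).
handler = EndpointHandler()

# "screenshot.png" is a placeholder; the dict form exercises the
# {"inputs": {"image": ...}} branch of __call__.
image = Image.open("screenshot.png")
result = handler({"inputs": {"image": image}})
print(result[0]["generated_text"])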