BoghdadyJR commited on
Commit
c607d34
·
verified ·
1 Parent(s): f8ddc4c

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +144 -0
handler.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoProcessor
3
+ from PIL import Image
4
+ import torch
5
+ import io
6
+ import base64
7
+ from peft import PeftModel
8
+
9
+ class EndpointHandler():
10
+ def __init__(self, model_dir: str):
11
+ self.path = model_dir
12
+ # Load base model and tokenizer
13
+ base_model_id = "Qwen/Qwen2-VL-2B-Instruct"
14
+
15
+ # Load tokenizer/processor
16
+ self.processor = AutoProcessor.from_pretrained(
17
+ self.path,
18
+ trust_remote_code=True
19
+ )
20
+
21
+ # Load base model
22
+ self.model = AutoModelForVision2Seq.from_pretrained(
23
+ base_model_id,
24
+ torch_dtype=torch.float16,
25
+ device_map="auto",
26
+ trust_remote_code=True
27
+ )
28
+
29
+ # Load LoRA adapter
30
+ self.model = PeftModel.from_pretrained(
31
+ self.model,
32
+ self.path,
33
+ device_map="auto"
34
+ )
35
+
36
+ # Merge and unload for faster inference
37
+ self.model = self.model.merge_and_unload()
38
+
39
+ # Set to eval mode
40
+ self.model.eval()
41
+
42
+ # Store the instruction template
43
+ self.instruction = """
44
+ A conversation between a Healthcare Provider and an AI Medical Image Analysis Assistant. The provider shares a medical image, and the Assistant generates a clear description/report. The assistant first analyzes the image systematically, then provides a concise report. The analysis process and report are enclosed within <thinking> </thinking><answer> </answer>.
45
+ Always respond in this format:
46
+ <thinking>
47
+ 1. Initial Assessment:
48
+ - What type of image is this? (X-ray, CT, MRI, etc.)
49
+ - Which body part/region is shown?
50
+ - Is the image quality adequate?
51
+ 2. Key Findings:
52
+ - What are the normal structures visible?
53
+ - Are there any abnormalities?
54
+ - What are the important measurements?
55
+ 3. Clinical Significance:
56
+ - What are the main clinical findings?
57
+ - Are there any critical findings?
58
+ </thinking>
59
+ <answer>
60
+ Brief Structured Report:
61
+ 1. EXAM TYPE: [imaging type and body region]
62
+ 2. FINDINGS: [key observations and abnormalities]
63
+ 3. IMPRESSION: [summary and clinical significance]
64
+ </answer>
65
+ """
66
+
67
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
68
+ """
69
+ data args:
70
+ inputs (:obj: `str` | `PIL.Image` | `np.array`)
71
+ parameters (:obj: `Dict[str, Any]`, *optional*)
72
+ Return:
73
+ A :obj:`list` | `dict`: will be serialized and returned
74
+ """
75
+ # Extract inputs and parameters
76
+ inputs = data.pop("inputs", data)
77
+ parameters = data.pop("parameters", {})
78
+
79
+ # Handle different input formats
80
+ if isinstance(inputs, str):
81
+ # Base64 encoded image
82
+ image_bytes = base64.b64decode(inputs)
83
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
84
+ elif isinstance(inputs, dict):
85
+ # Dictionary with image key
86
+ image_data = inputs.get("image", inputs.get("inputs", ""))
87
+ if isinstance(image_data, str):
88
+ image_bytes = base64.b64decode(image_data)
89
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
90
+ else:
91
+ image = image_data
92
+ else:
93
+ # Direct image
94
+ image = inputs
95
+
96
+ # Ensure image is RGB
97
+ if image.mode != "RGB":
98
+ image = image.convert("RGB")
99
+
100
+ # Prepare messages in Qwen format
101
+ messages = [
102
+ {
103
+ "role": "user",
104
+ "content": [
105
+ {"type": "image", "image": image},
106
+ {"type": "text", "text": self.instruction}
107
+ ]
108
+ }
109
+ ]
110
+
111
+ # Process inputs
112
+ text = self.processor.apply_chat_template(
113
+ messages,
114
+ tokenize=False,
115
+ add_generation_prompt=True
116
+ )
117
+
118
+ # Prepare inputs for model
119
+ inputs = self.processor(
120
+ text=[text],
121
+ images=[image],
122
+ padding=True,
123
+ return_tensors="pt"
124
+ ).to(self.model.device)
125
+
126
+ # Generate response
127
+ with torch.no_grad():
128
+ output_ids = self.model.generate(
129
+ **inputs,
130
+ max_new_tokens=parameters.get("max_new_tokens", 512),
131
+ temperature=parameters.get("temperature", 0.7),
132
+ top_p=parameters.get("top_p", 0.9),
133
+ do_sample=True,
134
+ pad_token_id=self.processor.tokenizer.pad_token_id,
135
+ eos_token_id=self.processor.tokenizer.eos_token_id,
136
+ )
137
+
138
+ # Decode output - only the generated part
139
+ output_text = self.processor.batch_decode(
140
+ output_ids[:, inputs.input_ids.shape[1]:],
141
+ skip_special_tokens=True
142
+ )[0]
143
+
144
+ return [{"generated_text": output_text}]