ubden committed on
Commit
44e300a
·
verified ·
1 Parent(s): 2804fb1

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +144 -48
handler.py CHANGED
@@ -100,30 +100,73 @@ class EndpointHandler:
100
  Returns:
101
  PIL Image object or None if something goes wrong
102
  """
 
 
 
 
103
  try:
104
  # Check if it's a URL (starts with http/https)
105
- if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
106
  print(f"🌐 Fetching image from URL: {image_input[:50]}...")
107
- response = requests.get(image_input, timeout=10)
 
 
 
108
  response.raise_for_status()
 
 
 
 
 
109
  image = Image.open(BytesIO(response.content)).convert('RGB')
110
- print("✅ Image downloaded successfully!")
111
  return image
112
 
113
- # Must be base64 then
114
- elif isinstance(image_input, str):
115
- print("🔍 Decoding base64 image...")
116
- # Remove the data URL prefix if it exists
117
- if "base64," in image_input:
118
- image_input = image_input.split("base64,")[1]
119
-
120
- image_data = base64.b64decode(image_input)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  image = Image.open(BytesIO(image_data)).convert('RGB')
122
- print("✅ Image decoded successfully!")
123
  return image
124
 
 
 
 
 
 
 
125
  except Exception as e:
126
- print(f"❌ Couldn't process the image: {e}")
127
  return None
128
 
129
  return None
@@ -154,15 +197,20 @@ class EndpointHandler:
154
 
155
  if isinstance(inputs, dict):
156
  # Dictionary input - check for text and image
157
- text = inputs.get("text", inputs.get("prompt", str(inputs)))
 
158
 
159
  # Check for image in various formats
160
  image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
161
  if image_input:
162
  image = self.process_image_input(image_input)
163
  if image:
164
- # For now, we'll add a note about the image since we're text-only
165
- text = f"[Image provided - {image.size[0]}x{image.size[1]} pixels] {text}"
 
 
 
 
166
  else:
167
  # Simple string input
168
  text = str(inputs)
@@ -172,29 +220,52 @@ class EndpointHandler:
172
 
173
  # Get generation parameters with sensible defaults
174
  parameters = data.get("parameters", {})
175
- max_new_tokens = min(parameters.get("max_new_tokens", 256), 1024)
176
- temperature = parameters.get("temperature", 0.7)
177
- top_p = parameters.get("top_p", 0.95)
178
- do_sample = parameters.get("do_sample", True)
179
- repetition_penalty = parameters.get("repetition_penalty", 1.0)
 
 
 
 
180
 
181
  # Using pipeline? Let's go!
182
  if self.use_pipeline:
183
- result = self.pipe(
184
- text,
185
- max_new_tokens=max_new_tokens,
186
- temperature=temperature,
187
- top_p=top_p,
188
- do_sample=do_sample,
189
- repetition_penalty=repetition_penalty,
190
- return_full_text=False # Just the new stuff, not the input
191
- )
192
 
193
- # Pipeline returns a list, let's handle it
 
 
 
 
 
 
194
  if isinstance(result, list) and len(result) > 0:
195
- return [{"generated_text": result[0].get("generated_text", "")}]
 
 
 
 
 
 
 
 
 
 
196
  else:
197
- return [{"generated_text": str(result)}]
 
 
 
 
198
 
199
  # Manual generation mode
200
  else:
@@ -203,7 +274,7 @@ class EndpointHandler:
203
  text,
204
  return_tensors="pt",
205
  truncation=True,
206
- max_length=2048
207
  )
208
 
209
  input_ids = encoded["input_ids"].to(self.device)
@@ -211,19 +282,33 @@ class EndpointHandler:
211
  if attention_mask is not None:
212
  attention_mask = attention_mask.to(self.device)
213
 
 
 
 
 
 
 
 
 
214
  # Generate the response
215
  with torch.no_grad():
216
- outputs = self.model.generate(
217
- input_ids,
218
- attention_mask=attention_mask,
219
- max_new_tokens=max_new_tokens,
220
- temperature=temperature,
221
- top_p=top_p,
222
- do_sample=do_sample,
223
- repetition_penalty=repetition_penalty,
224
- pad_token_id=self.tokenizer.pad_token_id,
225
- eos_token_id=self.tokenizer.eos_token_id
226
- )
 
 
 
 
 
 
227
 
228
  # Decode only the new tokens (not the input)
229
  generated_ids = outputs[0][input_ids.shape[-1]:]
@@ -233,13 +318,24 @@ class EndpointHandler:
233
  clean_up_tokenization_spaces=True
234
  )
235
 
236
- return [{"generated_text": generated_text}]
 
 
 
 
 
 
 
 
 
237
 
238
  except Exception as e:
239
- error_msg = f"Something went wrong during generation: {str(e)}"
240
  print(f"❌ {error_msg}")
241
  return [{
242
  "generated_text": "",
243
  "error": error_msg,
244
- "handler": "Ubden® Team Enhanced Handler"
 
 
245
  }]
 
100
  Returns:
101
  PIL Image object or None if something goes wrong
102
  """
103
+ if not image_input or not isinstance(image_input, str):
104
+ print("❌ Invalid image input provided")
105
+ return None
106
+
107
  try:
108
  # Check if it's a URL (starts with http/https)
109
+ if image_input.startswith(('http://', 'https://')):
110
  print(f"🌐 Fetching image from URL: {image_input[:50]}...")
111
+ headers = {
112
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
113
+ }
114
+ response = requests.get(image_input, timeout=15, headers=headers)
115
  response.raise_for_status()
116
+
117
+ # Verify it's actually an image
118
+ if not response.headers.get('content-type', '').startswith('image/'):
119
+ print(f"⚠️ URL doesn't seem to point to an image: {response.headers.get('content-type')}")
120
+
121
  image = Image.open(BytesIO(response.content)).convert('RGB')
122
+ print(f"✅ Image downloaded successfully! Size: {image.size}")
123
  return image
124
 
125
+ # Handle base64 images
126
+ else:
127
+ print("🔍 Processing base64 image...")
128
+ base64_data = image_input
129
+
130
+ # Remove data URL prefix if it exists (data:image/jpeg;base64,...)
131
+ if image_input.startswith('data:'):
132
+ if 'base64,' in image_input:
133
+ base64_data = image_input.split('base64,')[1]
134
+ else:
135
+ print("❌ Invalid data URL format - missing base64 encoding")
136
+ return None
137
+
138
+ # Clean up any whitespace
139
+ base64_data = base64_data.strip().replace('\n', '').replace('\r', '').replace(' ', '')
140
+
141
+ # Validate base64 format
142
+ try:
143
+ # Add padding if necessary
144
+ missing_padding = len(base64_data) % 4
145
+ if missing_padding:
146
+ base64_data += '=' * (4 - missing_padding)
147
+
148
+ image_data = base64.b64decode(base64_data, validate=True)
149
+ except Exception as decode_error:
150
+ print(f"❌ Invalid base64 encoding: {decode_error}")
151
+ return None
152
+
153
+ # Verify it's a valid image
154
+ if len(image_data) < 100: # Too small to be a real image
155
+ print("❌ Decoded data too small to be a valid image")
156
+ return None
157
+
158
  image = Image.open(BytesIO(image_data)).convert('RGB')
159
+ print(f"✅ Base64 image decoded successfully! Size: {image.size}")
160
  return image
161
 
162
+ except requests.exceptions.Timeout:
163
+ print("❌ Request timeout - image URL took too long to respond")
164
+ return None
165
+ except requests.exceptions.RequestException as e:
166
+ print(f"❌ Network error while fetching image: {e}")
167
+ return None
168
  except Exception as e:
169
+ print(f"❌ Error processing image: {e}")
170
  return None
171
 
172
  return None
 
197
 
198
  if isinstance(inputs, dict):
199
  # Dictionary input - check for text and image
200
+ # Support multiple text field names: query, text, prompt
201
+ text = inputs.get("query", inputs.get("text", inputs.get("prompt", "")))
202
 
203
  # Check for image in various formats
204
  image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
205
  if image_input:
206
  image = self.process_image_input(image_input)
207
  if image:
208
+ print(f"✅ Image processed successfully: {image.size[0]}x{image.size[1]} pixels")
209
+ # Add image context to the prompt for better processing
210
+ if text:
211
+ text = f"<image>\nUser query: {text}"
212
+ else:
213
+ text = "<image>\nAnalyze this medical image."
214
  else:
215
  # Simple string input
216
  text = str(inputs)
 
220
 
221
  # Get generation parameters with sensible defaults
222
  parameters = data.get("parameters", {})
223
+ max_new_tokens = min(parameters.get("max_new_tokens", 512), 2048) # Increased default
224
+ temperature = max(0.01, min(parameters.get("temperature", 0.2), 2.0)) # Clamp temperature
225
+ top_p = max(0.01, min(parameters.get("top_p", 0.9), 1.0)) # Clamp top_p
226
+ do_sample = parameters.get("do_sample", temperature > 0.01) # Auto-set based on temperature
227
+ repetition_penalty = max(1.0, min(parameters.get("repetition_penalty", 1.05), 2.0)) # Clamp penalty
228
+ stop_sequences = parameters.get("stop", ["</s>"]) # Support stop sequences
229
+ return_full_text = parameters.get("return_full_text", False)
230
+
231
+ print(f"🎛️ Generation params: max_tokens={max_new_tokens}, temp={temperature}, top_p={top_p}, rep_penalty={repetition_penalty}")
232
 
233
  # Using pipeline? Let's go!
234
  if self.use_pipeline:
235
+ generation_kwargs = {
236
+ "max_new_tokens": max_new_tokens,
237
+ "temperature": temperature,
238
+ "top_p": top_p,
239
+ "do_sample": do_sample,
240
+ "repetition_penalty": repetition_penalty,
241
+ "return_full_text": return_full_text
242
+ }
 
243
 
244
+ # Add stop sequences if supported
245
+ if stop_sequences and stop_sequences != ["</s>"]:
246
+ generation_kwargs["stop_sequence"] = stop_sequences[0] # Most pipelines support single stop
247
+
248
+ result = self.pipe(text, **generation_kwargs)
249
+
250
+ # Pipeline returns a list, let's handle it properly
251
  if isinstance(result, list) and len(result) > 0:
252
+ generated_text = result[0].get("generated_text", "")
253
+ # Clean up any stop sequences that might remain
254
+ for stop_seq in stop_sequences:
255
+ if generated_text.endswith(stop_seq):
256
+ generated_text = generated_text[:-len(stop_seq)].rstrip()
257
+
258
+ return [{
259
+ "generated_text": generated_text,
260
+ "model": "PULSE-7B",
261
+ "processing_method": "pipeline"
262
+ }]
263
  else:
264
+ return [{
265
+ "generated_text": str(result),
266
+ "model": "PULSE-7B",
267
+ "processing_method": "pipeline"
268
+ }]
269
 
270
  # Manual generation mode
271
  else:
 
274
  text,
275
  return_tensors="pt",
276
  truncation=True,
277
+ max_length=4096 # Increased context length
278
  )
279
 
280
  input_ids = encoded["input_ids"].to(self.device)
 
282
  if attention_mask is not None:
283
  attention_mask = attention_mask.to(self.device)
284
 
285
+ # Prepare stop token IDs
286
+ stop_token_ids = []
287
+ if stop_sequences:
288
+ for stop_seq in stop_sequences:
289
+ stop_tokens = self.tokenizer.encode(stop_seq, add_special_tokens=False)
290
+ if stop_tokens:
291
+ stop_token_ids.extend(stop_tokens)
292
+
293
  # Generate the response
294
  with torch.no_grad():
295
+ generation_kwargs = {
296
+ "input_ids": input_ids,
297
+ "attention_mask": attention_mask,
298
+ "max_new_tokens": max_new_tokens,
299
+ "temperature": temperature,
300
+ "top_p": top_p,
301
+ "do_sample": do_sample,
302
+ "repetition_penalty": repetition_penalty,
303
+ "pad_token_id": self.tokenizer.pad_token_id,
304
+ "eos_token_id": self.tokenizer.eos_token_id
305
+ }
306
+
307
+ # Add stop token IDs if we have them
308
+ if stop_token_ids:
309
+ generation_kwargs["eos_token_id"] = stop_token_ids + [self.tokenizer.eos_token_id]
310
+
311
+ outputs = self.model.generate(**generation_kwargs)
312
 
313
  # Decode only the new tokens (not the input)
314
  generated_ids = outputs[0][input_ids.shape[-1]:]
 
318
  clean_up_tokenization_spaces=True
319
  )
320
 
321
+ # Clean up any remaining stop sequences
322
+ for stop_seq in stop_sequences:
323
+ if generated_text.endswith(stop_seq):
324
+ generated_text = generated_text[:-len(stop_seq)].rstrip()
325
+
326
+ return [{
327
+ "generated_text": generated_text.strip(),
328
+ "model": "PULSE-7B",
329
+ "processing_method": "manual"
330
+ }]
331
 
332
  except Exception as e:
333
+ error_msg = f"Generation error: {str(e)}"
334
  print(f"❌ {error_msg}")
335
  return [{
336
  "generated_text": "",
337
  "error": error_msg,
338
+ "model": "PULSE-7B",
339
+ "handler": "Ubden® Team Enhanced Handler",
340
+ "success": False
341
  }]