ubden committed · verified
Commit 775dded · 1 Parent(s): b13fdfd

Update handler.py

Files changed (1)
  1. handler.py +103 -39
handler.py CHANGED
@@ -1,27 +1,38 @@
+"""
+PULSE-7B Enhanced Handler
+Ubden® Team - Edited by https://github.com/ck-cankurt
+Support: Text, Image URLs, and Base64 encoded images
+"""
+
 import torch
 from typing import Dict, List, Any
+import base64
+from io import BytesIO
+from PIL import Image
+import requests


 class EndpointHandler:
     def __init__(self, path=""):
         """
-        Initialize the handler for the PULSE-7B model.
-        Direct reference to the original model.
+        Hey there! Let's get this PULSE-7B model up and running.
+        We'll load it from the HuggingFace hub directly, so no worries about local files.

         Args:
-            path: Path to the model directory (not used, we load from HF hub)
+            path: Model directory path (we actually ignore this and load from HF hub)
         """
-        print("Initializing PULSE-7B handler...")
+        print("🚀 Starting up PULSE-7B handler...")
+        print("📝 Enhanced by Ubden® Team - github.com/ck-cankurt")

-        # Set up the device
+        # Let's see what hardware we're working with
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print(f"Using device: {self.device}")
+        print(f"🖥️ Running on: {self.device}")

         try:
-            # Use the pipeline - the simplest and most reliable method
+            # First attempt - using pipeline (easiest and most stable way)
             from transformers import pipeline

-            print("Loading model from HuggingFace Hub...")
+            print("📦 Fetching model from HuggingFace Hub...")
             self.pipe = pipeline(
                 "text-generation",
                 model="PULSE-ECG/PULSE-7B",
@@ -33,25 +44,25 @@ class EndpointHandler:
                     "use_safetensors": True
                 }
             )
-            print("Model loaded successfully via pipeline!")
+            print("Model loaded successfully via pipeline!")

         except Exception as e:
-            print(f"Pipeline loading failed: {e}")
-            print("Trying alternative loading method...")
+            print(f"⚠️ Pipeline didn't work out: {e}")
+            print("🔄 Let me try a different approach...")

             try:
-                # Alternative: load the model and tokenizer separately
+                # Plan B - load model and tokenizer separately
                 from transformers import AutoTokenizer, LlamaForCausalLM

-                # Load the tokenizer
-                print("Loading tokenizer...")
+                # Get the tokenizer ready
+                print("📖 Setting up tokenizer...")
                 self.tokenizer = AutoTokenizer.from_pretrained(
                     "PULSE-ECG/PULSE-7B",
                     trust_remote_code=True
                 )

-                # Load the model as Llama
-                print("Loading model as Llama...")
+                # Load the model as Llama (it works, trust me!)
+                print("🧠 Loading the model as Llama...")
                 self.model = LlamaForCausalLM.from_pretrained(
                     "PULSE-ECG/PULSE-7B",
                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -60,18 +71,18 @@ class EndpointHandler:
                     trust_remote_code=True
                 )

-                # Set the padding token
+                # Quick fix for padding token if it's missing
                 if self.tokenizer.pad_token is None:
                     self.tokenizer.pad_token = self.tokenizer.eos_token
                     self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

                 self.model.eval()
                 self.use_pipeline = False
-                print("Model loaded successfully via direct loading!")
+                print("Model loaded successfully via direct loading!")

             except Exception as e2:
-                print(f"Alternative loading also failed: {e2}")
-                # Last resort: a simple fallback message
+                print(f"😓 That didn't work either: {e2}")
+                # If all else fails, we'll handle it gracefully
                 self.pipe = None
                 self.model = None
                 self.tokenizer = None
@@ -79,35 +90,87 @@ class EndpointHandler:
         else:
             self.use_pipeline = True

+    def process_image_input(self, image_input):
+        """
+        Handle both URL and base64 image inputs like a champ!
+
+        Args:
+            image_input: Can be a URL string or base64 encoded image
+
+        Returns:
+            PIL Image object or None if something goes wrong
+        """
+        try:
+            # Check if it's a URL (starts with http/https)
+            if isinstance(image_input, str) and (image_input.startswith('http://') or image_input.startswith('https://')):
+                print(f"🌐 Fetching image from URL: {image_input[:50]}...")
+                response = requests.get(image_input, timeout=10)
+                response.raise_for_status()
+                image = Image.open(BytesIO(response.content)).convert('RGB')
+                print("✅ Image downloaded successfully!")
+                return image
+
+            # Must be base64 then
+            elif isinstance(image_input, str):
+                print("🔍 Decoding base64 image...")
+                # Remove the data URL prefix if it exists
+                if "base64," in image_input:
+                    image_input = image_input.split("base64,")[1]
+
+                image_data = base64.b64decode(image_input)
+                image = Image.open(BytesIO(image_data)).convert('RGB')
+                print("✅ Image decoded successfully!")
+                return image
+
+        except Exception as e:
+            print(f"❌ Couldn't process the image: {e}")
+            return None
+
+        return None
+
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
-        Process the inference request.
+        Main processing function - where the magic happens!

         Args:
-            data: Input data containing 'inputs' and optional 'parameters'
+            data: Input data with 'inputs' and optional 'parameters'

         Returns:
-            List containing the generated response
+            List with the generated response
         """
-        # Return an error if the model could not be loaded
+        # Quick check - is our model ready?
         if self.use_pipeline is None:
             return [{
-                "generated_text": "Model could not be loaded. Please check the deployment configuration.",
-                "error": "Model initialization failed"
+                "generated_text": "Oops! Model couldn't load properly. Please check the deployment settings.",
+                "error": "Model initialization failed",
+                "handler": "Ubden® Team Enhanced Handler"
             }]

         try:
-            # Read the inputs
+            # Parse the inputs - flexible format support
             inputs = data.get("inputs", "")
+            text = ""
+            image = None
+
             if isinstance(inputs, dict):
+                # Dictionary input - check for text and image
                 text = inputs.get("text", inputs.get("prompt", str(inputs)))
+
+                # Check for image in various formats
+                image_input = inputs.get("image", inputs.get("image_url", inputs.get("image_base64", None)))
+                if image_input:
+                    image = self.process_image_input(image_input)
+                    if image:
+                        # For now, we'll add a note about the image since we're text-only
+                        text = f"[Image provided - {image.size[0]}x{image.size[1]} pixels] {text}"
             else:
+                # Simple string input
                 text = str(inputs)

             if not text:
-                return [{"generated_text": "Please provide an input text."}]
+                return [{"generated_text": "Hey, I need some text to work with! Please provide an input."}]

-            # Read the parameters
+            # Get generation parameters with sensible defaults
             parameters = data.get("parameters", {})
             max_new_tokens = min(parameters.get("max_new_tokens", 256), 1024)
             temperature = parameters.get("temperature", 0.7)
@@ -115,7 +178,7 @@ class EndpointHandler:
             do_sample = parameters.get("do_sample", True)
             repetition_penalty = parameters.get("repetition_penalty", 1.0)

-            # If we're using the pipeline
+            # Using pipeline? Let's go!
             if self.use_pipeline:
                 result = self.pipe(
                     text,
@@ -124,18 +187,18 @@ class EndpointHandler:
                     top_p=top_p,
                     do_sample=do_sample,
                     repetition_penalty=repetition_penalty,
-                    return_full_text=False  # Return only the newly generated text
+                    return_full_text=False  # Just the new stuff, not the input
                 )

-                # The pipeline returns a list
+                # Pipeline returns a list, let's handle it
                 if isinstance(result, list) and len(result) > 0:
                     return [{"generated_text": result[0].get("generated_text", "")}]
                 else:
                     return [{"generated_text": str(result)}]

-            # If we're using manual generation
+            # Manual generation mode
             else:
-                # Tokenize
+                # Tokenize the input
                 encoded = self.tokenizer(
                     text,
                     return_tensors="pt",
@@ -148,7 +211,7 @@ class EndpointHandler:
                 if attention_mask is not None:
                     attention_mask = attention_mask.to(self.device)

-                # Generate
+                # Generate the response
                 with torch.no_grad():
                     outputs = self.model.generate(
                         input_ids,
@@ -162,7 +225,7 @@ class EndpointHandler:
                         eos_token_id=self.tokenizer.eos_token_id
                     )

-                # Decode - take only the new tokens
+                # Decode only the new tokens (not the input)
                 generated_ids = outputs[0][input_ids.shape[-1]:]
                 generated_text = self.tokenizer.decode(
                     generated_ids,
@@ -173,9 +236,10 @@ class EndpointHandler:
                 return [{"generated_text": generated_text}]

         except Exception as e:
-            error_msg = f"Error during generation: {str(e)}"
-            print(error_msg)
+            error_msg = f"Something went wrong during generation: {str(e)}"
+            print(f"❌ {error_msg}")
             return [{
                 "generated_text": "",
-                "error": error_msg
+                "error": error_msg,
+                "handler": "Ubden® Team Enhanced Handler"
             }]
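For reference, here is a minimal sketch of how the updated handler could be exercised locally. The payload keys (inputs.text, inputs.image / image_url / image_base64, and parameters) mirror what __call__ and process_image_input read in the diff above; the local instantiation, the sample prompt, the example URL, and the ecg.png path are illustrative assumptions, not part of the commit, and instantiating EndpointHandler will attempt to download the PULSE-ECG/PULSE-7B weights.

# Illustrative only: call the handler the same way an Inference Endpoint would.
# Assumes handler.py from this commit is on the import path and the model can be downloaded.
import base64

from handler import EndpointHandler

handler = EndpointHandler()

# Plain text prompt
print(handler({
    "inputs": "Describe the typical ECG findings in atrial fibrillation.",
    "parameters": {"max_new_tokens": 128, "temperature": 0.7},
}))

# Image by URL (hypothetical URL); the current handler only notes the image size in the prompt
print(handler({
    "inputs": {
        "text": "What does this ECG show?",
        "image_url": "https://example.com/ecg.png",
    },
}))

# Image as base64, accepted with or without a data-URL prefix ("ecg.png" is a placeholder file)
with open("ecg.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")
print(handler({"inputs": {"text": "What does this ECG show?", "image": encoded}}))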