ubden committed on
Commit b0ca732 · verified · 1 Parent(s): 6f28285

Update handler.py

Files changed (1): handler.py +120 -129
handler.py CHANGED
@@ -1,128 +1,83 @@
 import torch
 from typing import Dict, List, Any
-import json
-import os


 class EndpointHandler:
     def __init__(self, path=""):
         """
         Initialize the handler for PULSE-7B model.

         Args:
-            path: Path to the model directory
         """
-        print(f"Initializing handler with path: {path}")

         # Set up the device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"Using device: {self.device}")

-        # Manually load and patch the config file
-        config_path = os.path.join(path, "config.json")
-        if os.path.exists(config_path):
-            with open(config_path, 'r') as f:
-                config_data = json.load(f)

-            # Temporarily change the model type
-            original_model_type = config_data.get("model_type", "")
-            print(f"Original model type: {original_model_type}")

-            if original_model_type == "llava_llama":
-                # Create a temporary config file
-                config_data["model_type"] = "llama"
-                config_data["architectures"] = ["LlamaForCausalLM"]
-
-                temp_config_path = os.path.join(path, "temp_config.json")
-                with open(temp_config_path, 'w') as f:
-                    json.dump(config_data, f)
-
-                # Load as a Llama model
-                from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer
-
-                try:
-                    # Load the tokenizer
-                    print("Loading tokenizer...")
-                    self.tokenizer = AutoTokenizer.from_pretrained(
-                        path,
-                        use_fast=False,
-                        trust_remote_code=True
-                    )
-
-                    # Load the model as Llama
-                    print("Loading model as Llama...")
-                    self.model = LlamaForCausalLM.from_pretrained(
-                        path,
-                        config=temp_config_path,
-                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                        device_map="auto",
-                        low_cpu_mem_usage=True,
-                        ignore_mismatched_sizes=True
-                    )
-
-                    # Delete the temp config
-                    if os.path.exists(temp_config_path):
-                        os.remove(temp_config_path)
-
-                except Exception as e:
-                    print(f"Llama loading failed: {e}")
-                    # Simplest fallback: use AutoModel
-                    from transformers import AutoModel, AutoTokenizer
-
-                    self.tokenizer = AutoTokenizer.from_pretrained(
-                        path,
-                        trust_remote_code=True
-                    )
-
-                    self.model = AutoModel.from_pretrained(
-                        path,
-                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                        device_map="auto",
-                        trust_remote_code=True,
-                        ignore_mismatched_sizes=True
-                    )
-            else:
-                # Standard loading
-                from transformers import AutoModelForCausalLM, AutoTokenizer

                 self.tokenizer = AutoTokenizer.from_pretrained(
-                    path,
                     trust_remote_code=True
                 )

-                self.model = AutoModelForCausalLM.from_pretrained(
-                    path,
                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                     device_map="auto",
-                    trust_remote_code=True,
-                    low_cpu_mem_usage=True
                 )
         else:
-            # Config not found, try direct loading
-            print("Config not found, trying direct loading...")
-            from transformers import AutoModelForCausalLM, AutoTokenizer
-
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                "PULSE-ECG/PULSE-7B",
-                trust_remote_code=True
-            )
-
-            self.model = AutoModelForCausalLM.from_pretrained(
-                "PULSE-ECG/PULSE-7B",
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                device_map="auto",
-                trust_remote_code=True,
-                ignore_mismatched_sizes=True
-            )
-
-        # Set the padding token
-        if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-
-        # Put the model in eval mode
-        self.model.eval()
-        print("Handler initialization complete!")

     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
@@ -134,6 +89,13 @@ class EndpointHandler:
         Returns:
             List containing the generated response
         """
         try:
             # Get the inputs
             inputs = data.get("inputs", "")
@@ -145,46 +107,75 @@ class EndpointHandler:
             if not text:
                 return [{"generated_text": "Please provide an input text."}]

-            # Get the parameters (keep it simple)
             parameters = data.get("parameters", {})
-            max_new_tokens = min(parameters.get("max_new_tokens", 128), 512)
             temperature = parameters.get("temperature", 0.7)
             do_sample = parameters.get("do_sample", True)

-            # Tokenize
-            encoded = self.tokenizer(
-                text,
-                return_tensors="pt",
-                truncation=True,
-                max_length=1024
-            )
-
-            input_ids = encoded["input_ids"].to(self.device)
-
-            # Generate
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    input_ids,
                     max_new_tokens=max_new_tokens,
-                    temperature=temperature if do_sample else 1.0,
                     do_sample=do_sample,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    eos_token_id=self.tokenizer.eos_token_id
                 )

-            # Decode
-            generated_text = self.tokenizer.decode(
-                outputs[0],
-                skip_special_tokens=True
-            )
-
-            # Remove input from output
-            if generated_text.startswith(text):
-                generated_text = generated_text[len(text):].strip()
-
-            return [{"generated_text": generated_text}]

         except Exception as e:
             error_msg = f"Error during generation: {str(e)}"
             print(error_msg)
-            return [{"generated_text": "", "error": error_msg}]
 
 
 
 
 import torch
 from typing import Dict, List, Any


 class EndpointHandler:
     def __init__(self, path=""):
         """
         Initialize the handler for PULSE-7B model.
+        Direct reference to the original model.

         Args:
+            path: Path to the model directory (not used, we load from HF hub)
         """
+        print("Initializing PULSE-7B handler...")

         # Set up the device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         print(f"Using device: {self.device}")

+        try:
+            # Use a pipeline - the simplest and most reliable method
+            from transformers import pipeline

+            print("Loading model from HuggingFace Hub...")
+            self.pipe = pipeline(
+                "text-generation",
+                model="PULSE-ECG/PULSE-7B",
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                device=0 if torch.cuda.is_available() else -1,
+                trust_remote_code=True,
+                model_kwargs={
+                    "low_cpu_mem_usage": True,
+                    "use_safetensors": True
+                }
+            )
+            print("Model loaded successfully via pipeline!")

+        except Exception as e:
+            print(f"Pipeline loading failed: {e}")
+            print("Trying alternative loading method...")
+
+            try:
+                # Alternative: load the model and tokenizer separately
+                from transformers import AutoTokenizer, LlamaForCausalLM

+                # Load the tokenizer
+                print("Loading tokenizer...")
                 self.tokenizer = AutoTokenizer.from_pretrained(
+                    "PULSE-ECG/PULSE-7B",
                     trust_remote_code=True
                 )

+                # Load the model as Llama
+                print("Loading model as Llama...")
+                self.model = LlamaForCausalLM.from_pretrained(
+                    "PULSE-ECG/PULSE-7B",
                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                     device_map="auto",
+                    low_cpu_mem_usage=True,
+                    trust_remote_code=True
                 )
+
+                # Set the padding token
+                if self.tokenizer.pad_token is None:
+                    self.tokenizer.pad_token = self.tokenizer.eos_token
+                    self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+                self.model.eval()
+                self.use_pipeline = False
+                print("Model loaded successfully via direct loading!")
+
+            except Exception as e2:
+                print(f"Alternative loading also failed: {e2}")
+                # Last resort: a simple fallback message
+                self.pipe = None
+                self.model = None
+                self.tokenizer = None
+                self.use_pipeline = None
         else:
+            self.use_pipeline = True

     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """

         Returns:
             List containing the generated response
         """
+        # Return an error if the model could not be loaded
+        if self.use_pipeline is None:
+            return [{
+                "generated_text": "Model could not be loaded. Please check the deployment configuration.",
+                "error": "Model initialization failed"
+            }]
+
         try:
             # Get the inputs
             inputs = data.get("inputs", "")

             if not text:
                 return [{"generated_text": "Please provide an input text."}]

+            # Get the parameters
             parameters = data.get("parameters", {})
+            max_new_tokens = min(parameters.get("max_new_tokens", 256), 1024)
             temperature = parameters.get("temperature", 0.7)
+            top_p = parameters.get("top_p", 0.95)
             do_sample = parameters.get("do_sample", True)
+            repetition_penalty = parameters.get("repetition_penalty", 1.0)

+            # If we are using the pipeline
+            if self.use_pipeline:
+                result = self.pipe(
+                    text,
                     max_new_tokens=max_new_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
                     do_sample=do_sample,
+                    repetition_penalty=repetition_penalty,
+                    return_full_text=False  # Return only the newly generated text
                 )
+
+                # The pipeline returns a list
+                if isinstance(result, list) and len(result) > 0:
+                    return [{"generated_text": result[0].get("generated_text", "")}]
+                else:
+                    return [{"generated_text": str(result)}]

+            # If we are using manual generation
+            else:
+                # Tokenize
+                encoded = self.tokenizer(
+                    text,
+                    return_tensors="pt",
+                    truncation=True,
+                    max_length=2048
+                )
+
+                input_ids = encoded["input_ids"].to(self.device)
+                attention_mask = encoded.get("attention_mask")
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(self.device)
+
+                # Generate
+                with torch.no_grad():
+                    outputs = self.model.generate(
+                        input_ids,
+                        attention_mask=attention_mask,
+                        max_new_tokens=max_new_tokens,
+                        temperature=temperature,
+                        top_p=top_p,
+                        do_sample=do_sample,
+                        repetition_penalty=repetition_penalty,
+                        pad_token_id=self.tokenizer.pad_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id
+                    )
+
+                # Decode - take only the new tokens
+                generated_ids = outputs[0][input_ids.shape[-1]:]
+                generated_text = self.tokenizer.decode(
+                    generated_ids,
+                    skip_special_tokens=True,
+                    clean_up_tokenization_spaces=True
+                )
+
+                return [{"generated_text": generated_text}]

         except Exception as e:
             error_msg = f"Error during generation: {str(e)}"
             print(error_msg)
+            return [{
+                "generated_text": "",
+                "error": error_msg
+            }]