CanerDedeoglu committed on
Commit
f870018
·
verified ·
1 Parent(s): 7f46e2b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +11 -4
handler.py CHANGED
@@ -27,7 +27,7 @@ def _ensure_llava():
27
  _ensure_llava()
28
 
29
  # ---- LLaVA parçaları (demo akışı) ----
30
- from llava.model.builder import load_pretrained_model, get_model_name_from_path
31
  from llava.mm_utils import tokenizer_image_token, process_images
32
  from llava.constants import (
33
  IMAGE_TOKEN_INDEX,
@@ -38,6 +38,13 @@ from llava.constants import (
38
  from llava.conversation import conv_templates
39
  from llava.utils import disable_torch_init
40
 
 
 
 
 
 
 
 
41
  # Varsayılanlar
42
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2")
43
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
@@ -57,7 +64,7 @@ class EndpointHandler:
57
  def __init__(self, path: str = "") -> None:
58
  disable_torch_init()
59
 
60
- # PULSE-7B HFden/yerelden nereden yükleniyorsa yolu belirle
61
  if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
62
  model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
63
  elif os.getenv("HF_MODEL_ID", "").strip():
@@ -136,7 +143,7 @@ class EndpointHandler:
136
  image_tensors = process_images([pil], self.image_processor, self.model.config)
137
  image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
138
 
139
- # 3) tokenize (image token’ı gömülü)
140
  input_ids = tokenizer_image_token(
141
  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
142
  ).to(self.model.device, non_blocking=True)
@@ -159,4 +166,4 @@ class EndpointHandler:
159
  output_ids = self.model.generate(input_ids, images=image_tensors, **gen_kwargs)
160
 
161
  text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
162
- return [{"generated_text": text}]
 
27
  _ensure_llava()
28
 
29
  # ---- LLaVA parçaları (demo akışı) ----
30
+ from llava.model.builder import load_pretrained_model
31
  from llava.mm_utils import tokenizer_image_token, process_images
32
  from llava.constants import (
33
  IMAGE_TOKEN_INDEX,
 
38
  from llava.conversation import conv_templates
39
  from llava.utils import disable_torch_init
40
 
41
# Define the missing helper ourselves (this llava version does not export it)
def get_model_name_from_path(model_path: str) -> str:
    """Return the final path component of *model_path* as the model name.

    Stand-in for llava's helper of the same name: for a HF repo id such as
    ``"org/pulse-7b"`` or a local directory path, the last path segment is
    taken as the model name.

    Trailing ``/`` separators are stripped first so that a path like
    ``"models/pulse-7b/"`` yields ``"pulse-7b"`` instead of an empty string
    (``"a/b/".split("/")[-1]`` would be ``""``).
    """
    trimmed = model_path.rstrip("/")
    if "/" in trimmed:
        return trimmed.split("/")[-1]
    return trimmed
48
  # Varsayılanlar
49
  DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2")
50
  MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
 
64
  def __init__(self, path: str = "") -> None:
65
  disable_torch_init()
66
 
67
+ # PULSE-7B HF'den/yerelden nereden yükleniyorsa yolu belirle
68
  if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
69
  model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
70
  elif os.getenv("HF_MODEL_ID", "").strip():
 
143
  image_tensors = process_images([pil], self.image_processor, self.model.config)
144
  image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
145
 
146
+ # 3) tokenize (image token gömülü)
147
  input_ids = tokenizer_image_token(
148
  prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
149
  ).to(self.model.device, non_blocking=True)
 
166
  output_ids = self.model.generate(input_ids, images=image_tensors, **gen_kwargs)
167
 
168
  text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
169
+ return [{"generated_text": text}]