Update handler.py
Browse files- handler.py +11 -4
handler.py
CHANGED
|
@@ -27,7 +27,7 @@ def _ensure_llava():
|
|
| 27 |
_ensure_llava()
|
| 28 |
|
| 29 |
# ---- LLaVA parçaları (demo akışı) ----
|
| 30 |
-
from llava.model.builder import load_pretrained_model
|
| 31 |
from llava.mm_utils import tokenizer_image_token, process_images
|
| 32 |
from llava.constants import (
|
| 33 |
IMAGE_TOKEN_INDEX,
|
|
@@ -38,6 +38,13 @@ from llava.constants import (
|
|
| 38 |
from llava.conversation import conv_templates
|
| 39 |
from llava.utils import disable_torch_init
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
# Varsayılanlar
|
| 42 |
DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2")
|
| 43 |
MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
|
|
@@ -57,7 +64,7 @@ class EndpointHandler:
|
|
| 57 |
def __init__(self, path: str = "") -> None:
|
| 58 |
disable_torch_init()
|
| 59 |
|
| 60 |
-
# PULSE-7B HF
|
| 61 |
if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
|
| 62 |
model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
|
| 63 |
elif os.getenv("HF_MODEL_ID", "").strip():
|
|
@@ -136,7 +143,7 @@ class EndpointHandler:
|
|
| 136 |
image_tensors = process_images([pil], self.image_processor, self.model.config)
|
| 137 |
image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
|
| 138 |
|
| 139 |
-
# 3) tokenize (image token
|
| 140 |
input_ids = tokenizer_image_token(
|
| 141 |
prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 142 |
).to(self.model.device, non_blocking=True)
|
|
@@ -159,4 +166,4 @@ class EndpointHandler:
|
|
| 159 |
output_ids = self.model.generate(input_ids, images=image_tensors, **gen_kwargs)
|
| 160 |
|
| 161 |
text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
| 162 |
-
return [{"generated_text": text}]
|
|
|
|
| 27 |
_ensure_llava()
|
| 28 |
|
| 29 |
# ---- LLaVA parçaları (demo akışı) ----
|
| 30 |
+
from llava.model.builder import load_pretrained_model
|
| 31 |
from llava.mm_utils import tokenizer_image_token, process_images
|
| 32 |
from llava.constants import (
|
| 33 |
IMAGE_TOKEN_INDEX,
|
|
|
|
| 38 |
from llava.conversation import conv_templates
|
| 39 |
from llava.utils import disable_torch_init
|
| 40 |
|
| 41 |
+
# Define the helper missing from this LLaVA build (normally provided by llava.mm_utils).
def get_model_name_from_path(model_path: str) -> str:
    """Return the final path component of *model_path* as the model name.

    Works for both HF repo ids ("org/model-name") and local directories
    ("/models/pulse-7b"). Trailing slashes are stripped first so a value
    taken from an env var such as HF_MODEL_LOCAL_DIR with a trailing "/"
    no longer yields an empty model name.

    Args:
        model_path: HF repo id or filesystem path to the model.

    Returns:
        The last non-empty path segment, or the input itself when it
        contains no "/".
    """
    # Strip trailing separators so "/models/pulse-7b/" -> "pulse-7b" (bug fix:
    # the original returned "" for paths ending in "/").
    model_path = model_path.rstrip("/")
    if "/" in model_path:
        return model_path.split("/")[-1]
    return model_path
|
| 47 |
+
|
| 48 |
# Varsayılanlar
|
| 49 |
DEFAULT_CONV_MODE = os.getenv("LLAVA_CONV_MODE", "llava_v2")
|
| 50 |
MAX_NEW_TOKENS_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
|
|
|
|
| 64 |
def __init__(self, path: str = "") -> None:
|
| 65 |
disable_torch_init()
|
| 66 |
|
| 67 |
+
# PULSE-7B HF'den/yerelden nereden yükleniyorsa yolu belirle
|
| 68 |
if os.getenv("HF_MODEL_LOCAL_DIR", "").strip():
|
| 69 |
model_path = os.getenv("HF_MODEL_LOCAL_DIR").strip()
|
| 70 |
elif os.getenv("HF_MODEL_ID", "").strip():
|
|
|
|
| 143 |
image_tensors = process_images([pil], self.image_processor, self.model.config)
|
| 144 |
image_tensors = image_tensors.to(self.model.device, dtype=self.model.dtype, non_blocking=True)
|
| 145 |
|
| 146 |
+
# 3) tokenize (image token'ı gömülü)
|
| 147 |
input_ids = tokenizer_image_token(
|
| 148 |
prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 149 |
).to(self.model.device, non_blocking=True)
|
|
|
|
| 166 |
output_ids = self.model.generate(input_ids, images=image_tensors, **gen_kwargs)
|
| 167 |
|
| 168 |
text = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
|
| 169 |
+
return [{"generated_text": text}]
|