Update handler.py
Browse files- handler.py +5 -26
handler.py
CHANGED
|
@@ -10,6 +10,7 @@ import json
|
|
| 10 |
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
|
|
|
| 13 |
from PIL import Image
|
| 14 |
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
| 15 |
from transformers.utils import logging
|
|
@@ -77,8 +78,6 @@ class PixtralForRegression(nn.Module):
|
|
| 77 |
am = attention_mask.to(last_h.device).long()
|
| 78 |
ids = input_ids.to(last_h.device)
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
if self.pooling == "mean_image_tokens":
|
| 83 |
# Mean over all [IMG] placeholder tokens
|
| 84 |
img_mask = (ids == self.image_token_id) & (am == 1) # [B,L]
|
|
@@ -102,11 +101,6 @@ class PixtralForRegression(nn.Module):
|
|
| 102 |
bsz = last_h.size(0)
|
| 103 |
return last_h[torch.arange(bsz, device=last_h.device), idx] # [B,H]
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
raise ValueError(f"Unknown pooling: {self.pooling}")
|
| 108 |
-
|
| 109 |
-
|
| 110 |
raise ValueError(f"Unknown pooling: {self.pooling}")
|
| 111 |
|
| 112 |
def forward(self, input_ids, attention_mask, pixel_values, **kwargs):
|
|
@@ -146,7 +140,7 @@ class PixtralForRegression(nn.Module):
|
|
| 146 |
self._dbg = True
|
| 147 |
img_mask = (input_ids == self.image_token_id) & (attention_mask == 1)
|
| 148 |
print("IMG tokens per sample:", img_mask.sum(dim=1)[:4].tolist())
|
| 149 |
-
pooled = self._pool(last_h, attention_mask)
|
| 150 |
raw = self.reg_head(pooled.to(torch.float32)).squeeze(-1)
|
| 151 |
preds = F.softplus(raw) + 1.0
|
| 152 |
return {"logits": preds}
|
|
@@ -249,21 +243,6 @@ class EndpointHandler:
|
|
| 249 |
raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.")
|
| 250 |
return str(subcat).strip()
|
| 251 |
|
| 252 |
-
def _build_chat_text(self, prompt: str) -> str:
|
| 253 |
-
messages = [
|
| 254 |
-
{
|
| 255 |
-
"role": "user",
|
| 256 |
-
"content": [
|
| 257 |
-
{"type": "image"},
|
| 258 |
-
{"type": "text", "text": prompt}
|
| 259 |
-
],
|
| 260 |
-
}
|
| 261 |
-
]
|
| 262 |
-
return self.processor.apply_chat_template(
|
| 263 |
-
messages,
|
| 264 |
-
add_generation_prompt=True,
|
| 265 |
-
tokenize=False,
|
| 266 |
-
)
|
| 267 |
|
| 268 |
def _build_regression_text(self, prompt: str) -> str:
|
| 269 |
"""
|
|
@@ -285,8 +264,8 @@ class EndpointHandler:
|
|
| 285 |
messages,
|
| 286 |
add_generation_prompt=False,
|
| 287 |
tokenize=False,
|
| 288 |
-
)
|
| 289 |
-
return chat
|
| 290 |
|
| 291 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
| 292 |
inputs = data.get("inputs", data)
|
|
@@ -301,7 +280,7 @@ class EndpointHandler:
|
|
| 301 |
if not sub_category:
|
| 302 |
raise ValueError("Missing 'sub_category' (or 'subcat') in 'inputs'.")
|
| 303 |
|
| 304 |
-
prompt = DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category)
|
| 305 |
|
| 306 |
image = self._decode_image(image_b64)
|
| 307 |
image = self._resize_max_side(image, max_side=int(inputs.get("max_side", self.cfg.max_side)))
|
|
|
|
| 10 |
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
from PIL import Image
|
| 15 |
from transformers import AutoProcessor, LlavaForConditionalGeneration
|
| 16 |
from transformers.utils import logging
|
|
|
|
| 78 |
am = attention_mask.to(last_h.device).long()
|
| 79 |
ids = input_ids.to(last_h.device)
|
| 80 |
|
|
|
|
|
|
|
| 81 |
if self.pooling == "mean_image_tokens":
|
| 82 |
# Mean over all [IMG] placeholder tokens
|
| 83 |
img_mask = (ids == self.image_token_id) & (am == 1) # [B,L]
|
|
|
|
| 101 |
bsz = last_h.size(0)
|
| 102 |
return last_h[torch.arange(bsz, device=last_h.device), idx] # [B,H]
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
raise ValueError(f"Unknown pooling: {self.pooling}")
|
| 105 |
|
| 106 |
def forward(self, input_ids, attention_mask, pixel_values, **kwargs):
|
|
|
|
| 140 |
self._dbg = True
|
| 141 |
img_mask = (input_ids == self.image_token_id) & (attention_mask == 1)
|
| 142 |
print("IMG tokens per sample:", img_mask.sum(dim=1)[:4].tolist())
|
| 143 |
+
pooled = self._pool(last_h, attention_mask, input_ids)
|
| 144 |
raw = self.reg_head(pooled.to(torch.float32)).squeeze(-1)
|
| 145 |
preds = F.softplus(raw) + 1.0
|
| 146 |
return {"logits": preds}
|
|
|
|
| 243 |
raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.")
|
| 244 |
return str(subcat).strip()
|
| 245 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
def _build_regression_text(self, prompt: str) -> str:
|
| 248 |
"""
|
|
|
|
| 264 |
messages,
|
| 265 |
add_generation_prompt=False,
|
| 266 |
tokenize=False,
|
| 267 |
+
)
|
| 268 |
+
return chat
|
| 269 |
|
| 270 |
def __call__(self, data: Dict[str, Any]) -> Any:
|
| 271 |
inputs = data.get("inputs", data)
|
|
|
|
| 280 |
if not sub_category:
|
| 281 |
raise ValueError("Missing 'sub_category' (or 'subcat') in 'inputs'.")
|
| 282 |
|
| 283 |
+
prompt = DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category).rstrip() + "\n\nANSWER:"
|
| 284 |
|
| 285 |
image = self._decode_image(image_b64)
|
| 286 |
image = self._resize_max_side(image, max_side=int(inputs.get("max_side", self.cfg.max_side)))
|