Kalaoke committed on
Commit
117c5b6
·
verified ·
1 Parent(s): 3e2055d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -26
handler.py CHANGED
@@ -10,6 +10,7 @@ import json
10
 
11
  import torch
12
  import torch.nn as nn
 
13
  from PIL import Image
14
  from transformers import AutoProcessor, LlavaForConditionalGeneration
15
  from transformers.utils import logging
@@ -77,8 +78,6 @@ class PixtralForRegression(nn.Module):
77
  am = attention_mask.to(last_h.device).long()
78
  ids = input_ids.to(last_h.device)
79
 
80
-
81
-
82
  if self.pooling == "mean_image_tokens":
83
  # Mean over all [IMG] placeholder tokens
84
  img_mask = (ids == self.image_token_id) & (am == 1) # [B,L]
@@ -102,11 +101,6 @@ class PixtralForRegression(nn.Module):
102
  bsz = last_h.size(0)
103
  return last_h[torch.arange(bsz, device=last_h.device), idx] # [B,H]
104
 
105
-
106
-
107
- raise ValueError(f"Unknown pooling: {self.pooling}")
108
-
109
-
110
  raise ValueError(f"Unknown pooling: {self.pooling}")
111
 
112
  def forward(self, input_ids, attention_mask, pixel_values, **kwargs):
@@ -146,7 +140,7 @@ class PixtralForRegression(nn.Module):
146
  self._dbg = True
147
  img_mask = (input_ids == self.image_token_id) & (attention_mask == 1)
148
  print("IMG tokens per sample:", img_mask.sum(dim=1)[:4].tolist())
149
- pooled = self._pool(last_h, attention_mask)
150
  raw = self.reg_head(pooled.to(torch.float32)).squeeze(-1)
151
  preds = F.softplus(raw) + 1.0
152
  return {"logits": preds}
@@ -249,21 +243,6 @@ class EndpointHandler:
249
  raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.")
250
  return str(subcat).strip()
251
 
252
- def _build_chat_text(self, prompt: str) -> str:
253
- messages = [
254
- {
255
- "role": "user",
256
- "content": [
257
- {"type": "image"},
258
- {"type": "text", "text": prompt}
259
- ],
260
- }
261
- ]
262
- return self.processor.apply_chat_template(
263
- messages,
264
- add_generation_prompt=True,
265
- tokenize=False,
266
- )
267
 
268
  def _build_regression_text(self, prompt: str) -> str:
269
  """
@@ -285,8 +264,8 @@ class EndpointHandler:
285
  messages,
286
  add_generation_prompt=False,
287
  tokenize=False,
288
- ).rstrip()
289
- return chat + "\n\nANSWER:"
290
 
291
  def __call__(self, data: Dict[str, Any]) -> Any:
292
  inputs = data.get("inputs", data)
@@ -301,7 +280,7 @@ class EndpointHandler:
301
  if not sub_category:
302
  raise ValueError("Missing 'sub_category' (or 'subcat') in 'inputs'.")
303
 
304
- prompt = DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category)
305
 
306
  image = self._decode_image(image_b64)
307
  image = self._resize_max_side(image, max_side=int(inputs.get("max_side", self.cfg.max_side)))
 
10
 
11
  import torch
12
  import torch.nn as nn
13
+ import torch.nn.functional as F
14
  from PIL import Image
15
  from transformers import AutoProcessor, LlavaForConditionalGeneration
16
  from transformers.utils import logging
 
78
  am = attention_mask.to(last_h.device).long()
79
  ids = input_ids.to(last_h.device)
80
 
 
 
81
  if self.pooling == "mean_image_tokens":
82
  # Mean over all [IMG] placeholder tokens
83
  img_mask = (ids == self.image_token_id) & (am == 1) # [B,L]
 
101
  bsz = last_h.size(0)
102
  return last_h[torch.arange(bsz, device=last_h.device), idx] # [B,H]
103
 
 
 
 
 
 
104
  raise ValueError(f"Unknown pooling: {self.pooling}")
105
 
106
  def forward(self, input_ids, attention_mask, pixel_values, **kwargs):
 
140
  self._dbg = True
141
  img_mask = (input_ids == self.image_token_id) & (attention_mask == 1)
142
  print("IMG tokens per sample:", img_mask.sum(dim=1)[:4].tolist())
143
+ pooled = self._pool(last_h, attention_mask, input_ids)
144
  raw = self.reg_head(pooled.to(torch.float32)).squeeze(-1)
145
  preds = F.softplus(raw) + 1.0
146
  return {"logits": preds}
 
243
  raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.")
244
  return str(subcat).strip()
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
  def _build_regression_text(self, prompt: str) -> str:
248
  """
 
264
  messages,
265
  add_generation_prompt=False,
266
  tokenize=False,
267
+ )
268
+ return chat
269
 
270
  def __call__(self, data: Dict[str, Any]) -> Any:
271
  inputs = data.get("inputs", data)
 
280
  if not sub_category:
281
  raise ValueError("Missing 'sub_category' (or 'subcat') in 'inputs'.")
282
 
283
+ prompt = DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category).rstrip() + "\n\nANSWER:"
284
 
285
  image = self._decode_image(image_b64)
286
  image = self._resize_max_side(image, max_side=int(inputs.get("max_side", self.cfg.max_side)))