# handler.py — custom Hugging Face Inference Endpoints handler
# (uploaded by Kalaoke; commit 117c5b6, verified)
from __future__ import annotations
import base64
import os
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, Optional
import glob
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers.utils import logging
from safetensors.torch import load_file as safe_load_file
logger = logging.get_logger(__name__)
logging.set_verbosity_info()
# Prompt template sent to the vision-language model. {SUB_CATEGORY} is filled
# in per request by EndpointHandler.__call__, which also appends "\n\nANSWER:".
DEFAULT_PROMPT = """
Here is a photo of some food waste.
Your task is to provide a weight estimate (in grams) for the following food waste:
{SUB_CATEGORY}
While providing your weight estimate, please keep in mind the following:
*** there might be other food items on the photo, I want you to provide the weight estimate for {SUB_CATEGORY} only!
*** make sure to provide your best guess, never say 0!
OUTPUT (digits only, no extra text): WEIGHT
""".strip()
@dataclass
class InferenceConfig:
    """Default inference-time limits used by EndpointHandler."""
    max_length: int = 2048  # tokenizer truncation limit (tokens)
    max_side: int = 512  # images are downscaled so max(width, height) <= max_side
    default_last_k: int = 32  # fallback K for "mean_last_k_tokens" pooling
class PixtralForRegression(nn.Module):
    """LLaVA/Pixtral base model topped with a scalar regression head.

    The base model is run with ``output_hidden_states=True``; the last hidden
    layer is pooled into one vector per sample (see ``_pool``) and fed to a
    small float32 MLP that predicts a single strictly-positive value.
    """

    def __init__(self, base, pooling: str = "mean_image_tokens", last_k: int = 32):
        super().__init__()
        self.base = base
        self.pooling = pooling  # one of the strategies handled in _pool
        self.last_k = last_k  # K used by "mean_last_k_tokens"
        # Id of the [IMG] placeholder token; 10 is the fallback when the base
        # config does not expose image_token_index.
        self.image_token_id = int(getattr(getattr(base, "config", None), "image_token_index", 10))
        # Infer the text hidden size; different transformers versions expose it
        # in different places, so probe them in order.
        hidden = None
        if hasattr(base, "language_model") and hasattr(base.language_model, "config"):
            hidden = getattr(base.language_model.config, "hidden_size", None)
        if hidden is None and hasattr(base, "config") and hasattr(base.config, "text_config"):
            hidden = getattr(base.config.text_config, "hidden_size", None)
        if hidden is None and hasattr(base, "config"):
            hidden = getattr(base.config, "hidden_size", None)
        if hidden is None:
            raise ValueError("Could not infer hidden_size from model config.")
        # Keep the head in float32 even when the base runs in bf16.
        self.reg_head = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, 1024),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 1),
        ).to(dtype=torch.float32)

    def _pool(self, last_h, attention_mask, input_ids):
        """Pool per-token hidden states into one vector per sample.

        Args:
            last_h: [B, L, H] last-layer hidden states.
            attention_mask: [B, L] (1 for real tokens, 0 for padding).
            input_ids: [B, L] token ids (used to locate image tokens).

        Returns:
            [B, H] pooled representation.

        Raises:
            ValueError: if ``self.pooling`` names an unknown strategy.
        """
        am = attention_mask.to(last_h.device).long()
        ids = input_ids.to(last_h.device)
        if self.pooling == "mean_image_tokens":
            # Mean over all [IMG] placeholder tokens.
            img_mask = (ids == self.image_token_id) & (am == 1)  # [B,L]
            denom = img_mask.sum(dim=1, keepdim=True).clamp(min=1)  # [B,1], avoid /0
            pooled = (last_h * img_mask.unsqueeze(-1)).sum(dim=1) / denom
            return pooled  # [B,H]
        if self.pooling == "mean_last_k_tokens":
            # Mean over the last K *non-pad* tokens of each sample.
            bsz = last_h.size(0)
            out = []
            lengths = am.sum(dim=1).clamp(min=1)  # [B]
            for b in range(bsz):
                end = int(lengths[b].item())
                start = max(0, end - self.last_k)
                out.append(last_h[b, start:end].mean(dim=0))
            return torch.stack(out, dim=0)  # [B,H]
        if self.pooling == "last_nonpad":
            # Hidden state of the final non-pad token of each sample.
            idx = (am.sum(dim=1) - 1).clamp(min=0)  # [B]
            bsz = last_h.size(0)
            return last_h[torch.arange(bsz, device=last_h.device), idx]  # [B,H]
        raise ValueError(f"Unknown pooling: {self.pooling}")

    def forward(self, input_ids, attention_mask, pixel_values, **kwargs):
        """Run the base model, pool the last hidden layer, and regress.

        Returns:
            dict with "logits": [B] strictly-positive predictions
            (softplus(raw) + 1.0).
        """
        # Drop training-only kwargs the base model would reject.
        kwargs.pop("labels", None)
        kwargs.pop("num_items_in_batch", None)
        image_sizes = kwargs.pop("image_sizes", None)
        # Normalize image_sizes to [B, 2]; derive it from pixel_values if absent.
        if image_sizes is not None and torch.is_tensor(image_sizes) and image_sizes.ndim == 1:
            image_sizes = image_sizes.unsqueeze(0)
        if image_sizes is None and torch.is_tensor(pixel_values) and pixel_values.ndim == 4:
            B, _, H, W = pixel_values.shape
            image_sizes = torch.tensor([[H, W]] * B, dtype=torch.long, device=pixel_values.device)
        base_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True,
        )
        if image_sizes is not None:
            base_kwargs["image_sizes"] = image_sizes
        out = self.base(**base_kwargs, **kwargs)
        # hidden_states location varies across transformers versions/models.
        hs = getattr(out, "hidden_states", None)
        if hs is None and hasattr(out, "language_model_outputs"):
            hs = getattr(out.language_model_outputs, "hidden_states", None)
        if hs is None:
            raise ValueError("No hidden_states found in model outputs.")
        last_h = hs[-1]
        # One-time sanity log on the first forward pass (was a bare print).
        if not hasattr(self, "_dbg"):
            self._dbg = True
            img_mask = (input_ids == self.image_token_id) & (attention_mask == 1)
            logger.info("IMG tokens per sample: %s", img_mask.sum(dim=1)[:4].tolist())
        pooled = self._pool(last_h, attention_mask, input_ids)
        # Head always runs in float32 regardless of the base dtype.
        raw = self.reg_head(pooled.to(torch.float32)).squeeze(-1)
        preds = F.softplus(raw) + 1.0  # strictly positive, lower-bounded by 1
        return {"logits": preds}
class EndpointHandler:
    """Inference Endpoints handler: estimates food-waste weight (grams) from a photo.

    Loads a LLaVA/Pixtral base model plus a separately stored regression head
    (``regression_head.safetensors`` + ``regression_head.json``) from ``path``.
    ``__call__`` takes a base64-encoded image and a food sub-category and
    returns an integer gram estimate (always >= 1).
    """

    def __init__(self, path: str = ".") -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = InferenceConfig()
        logger.info("Initializing EndpointHandler on device: %s", self.device)
        logger.info("Model dir listing: %s", sorted(os.listdir(path)))
        logger.info("Bin shards present: %s", sorted(glob.glob(os.path.join(path, "pytorch_model-*.bin"))))
        # Label-normalization stats; filled from regression_head.json when present.
        self.y_mean: Optional[float] = None
        self.y_std: Optional[float] = None
        # --- processor ---
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        if hasattr(self.processor, "tokenizer"):
            # Truncate from the left so the tail (question + "ANSWER:" cue) survives.
            self.processor.tokenizer.truncation_side = "left"
            tok = self.processor.tokenizer
            if tok.pad_token_id is None and tok.eos_token_id is not None:
                tok.pad_token = tok.eos_token
                tok.pad_token_id = tok.eos_token_id
        ip = getattr(self.processor, "image_processor", None)
        if ip is not None:
            # Images are already downscaled in _resize_max_side; avoid a second
            # resize/crop inside the processor.
            if hasattr(ip, "do_resize"):
                ip.do_resize = False
            if hasattr(ip, "do_center_crop"):
                ip.do_center_crop = False
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        base = LlavaForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to(self.device)
        if hasattr(base, "config"):
            base.config.use_cache = False  # no generation here; skip the KV cache
        self._dtype = base.dtype
        # --- load regression head (safetensors weights + json config) ---
        head_st = os.path.join(path, "regression_head.safetensors")
        head_js = os.path.join(path, "regression_head.json")
        if not os.path.exists(head_st) or not os.path.exists(head_js):
            raise ValueError(
                f"Missing regression head files. Expected BOTH:\n- {head_st}\n- {head_js}"
            )
        with open(head_js, "r") as f:
            head_cfg = json.load(f)
        pooling = head_cfg.get("pooling", "mean_last_k_tokens")
        last_k = head_cfg.get("last_k", self.cfg.default_last_k)
        norm = head_cfg.get("label_norm") or {}
        if "mean" in norm and "std" in norm:
            self.y_mean = float(norm["mean"])
            self.y_std = float(norm["std"])
            if self.y_std <= 0:
                raise ValueError(f"Invalid label_norm.std={self.y_std} in regression_head.json")
        self.model = PixtralForRegression(base, pooling=pooling, last_k=last_k)
        head_sd = safe_load_file(head_st)
        self.model.reg_head.load_state_dict(head_sd, strict=True)
        self.model.to(self.device)
        self.model.eval()
        logger.info("Loaded base + regression head. pooling=%s dtype=%s", pooling, str(self._dtype))
        if self.y_mean is not None and self.y_std is not None:
            logger.info("Label normalization enabled: mean=%.6f std=%.6f", self.y_mean, self.y_std)

    @staticmethod
    def _decode_image(image_b64: str) -> Image.Image:
        """Decode a base64 string into an RGB PIL image."""
        img_bytes = base64.b64decode(image_b64)
        return Image.open(BytesIO(img_bytes)).convert("RGB")

    @staticmethod
    def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
        """Downscale *img* so its longest side is <= max_side (never upscales)."""
        w, h = img.size
        m = max(w, h)
        if m <= max_side:
            return img
        scale = max_side / m
        return img.resize((int(w * scale), int(h * scale)), resample=Image.Resampling.LANCZOS)

    @staticmethod
    def _resolve_sub_category(inputs: Dict[str, Any]) -> str:
        """Return the food sub-category, accepting 'sub_category' or 'subcat'.

        Raises:
            ValueError: if neither key holds a non-empty value.
        """
        subcat = inputs.get("sub_category") or inputs.get("subcat")
        if not subcat or not str(subcat).strip():
            raise ValueError("Missing 'sub_category' (or 'subcat') in 'inputs'.")
        return str(subcat).strip()

    def _build_regression_text(self, prompt: str) -> str:
        """Render the single-turn chat template (image + *prompt*) as plain text.

        ``add_generation_prompt=False`` because we regress on hidden states
        rather than generate; the caller appends the "\\n\\nANSWER:" cue to the
        prompt before this method is invoked.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        chat = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=False,
            tokenize=False,
        )
        return chat

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one request.

        Expected payload (under 'inputs', or at the top level):
            image (str): base64-encoded photo — required.
            sub_category / subcat (str): food item to weigh — required.
            max_side (int, optional): resize limit, default cfg.max_side.
            max_length (int, optional): token truncation, default cfg.max_length.
            debug (bool, optional): extra logging.

        Returns:
            int: estimated weight in grams (>= 1).

        Raises:
            ValueError: if a required field is missing or empty.
        """
        inputs = data.get("inputs", data)
        debug = bool(inputs.get("debug", False))
        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")
        # _resolve_sub_category raises on a missing/empty value and returns a
        # stripped, non-empty string — no re-check needed here.
        sub_category = self._resolve_sub_category(inputs)
        prompt = DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category).rstrip() + "\n\nANSWER:"
        image = self._decode_image(image_b64)
        image = self._resize_max_side(image, max_side=int(inputs.get("max_side", self.cfg.max_side)))
        max_length = int(inputs.get("max_length", self.cfg.max_length))
        text = self._build_regression_text(prompt)
        enc = self.processor(
            text=[text],
            images=[image],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=False,
        )
        enc = {k: (v.to(self.device) if torch.is_tensor(v) else v) for k, v in enc.items()}
        enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self._dtype)
        # Llava expects (height, width) per image.
        w, h = image.size
        image_sizes = torch.tensor([[h, w]], dtype=torch.long, device=self.device)
        if debug:
            tok = self.processor.tokenizer
            last_idx = int(enc["attention_mask"][0].sum().item()) - 1
            last_id = int(enc["input_ids"][0, last_idx].item())
            logger.info("last_token_id=%d last_token='%s'", last_id, tok.decode([last_id], skip_special_tokens=False))
            logger.info("pixel_values shape: %s", tuple(enc["pixel_values"].shape))
            logger.info("image_sizes: %s", image_sizes[0].tolist())
        with torch.inference_mode():
            out = self.model(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                pixel_values=enc["pixel_values"],
                image_sizes=image_sizes,
            )
        pred_norm = float(out["logits"].item())
        # If trained with mean/std label normalization, undo it here.
        if self.y_mean is not None and self.y_std is not None:
            pred_g = pred_norm * float(self.y_std) + float(self.y_mean)
        else:
            pred_g = pred_norm
        pred_g = max(1.0, pred_g)  # never report 0 g (matches the prompt contract)
        if debug:
            logger.info("pred_norm=%.6f -> pred_g=%.3f", pred_norm, pred_g)
        return int(round(pred_g))