| | from __future__ import annotations |
| |
|
| | import base64 |
| | import os |
| | from dataclasses import dataclass |
| | from io import BytesIO |
| | from typing import Any, Dict, Optional |
| | import glob |
| | import json |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from PIL import Image |
| | from transformers import AutoProcessor, LlavaForConditionalGeneration |
| | from transformers.utils import logging |
| | from safetensors.torch import load_file as safe_load_file |
| |
|
| |
|
| | logger = logging.get_logger(__name__) |
| | logging.set_verbosity_info() |
| |
|
| |
|
# Prompt template sent to the vision-language model. {SUB_CATEGORY} is filled
# in per request by EndpointHandler.__call__, which also appends an "ANSWER:"
# cue after formatting. The wording (digits only, never 0) steers the model
# toward a bare numeric weight in grams.
DEFAULT_PROMPT = """
Here is a photo of some food waste.

Your task is to provide a weight estimate (in grams) for the following food waste:
{SUB_CATEGORY}

While providing your weight estimate, please keep in mind the following:
*** there might be other food items on the photo, I want you to provide the weight estimate for {SUB_CATEGORY} only!
*** make sure to provide your best guess, never say 0!

OUTPUT (digits only, no extra text): WEIGHT
""".strip()
| |
|
| |
|
@dataclass
class InferenceConfig:
    """Default knobs for a single inference request.

    Each field can be overridden per request via the corresponding key in
    the request ``inputs`` dict.
    """

    max_length: int = 2048    # tokenizer truncation limit for the prompt
    max_side: int = 512       # longest image side after client-side downscale
    default_last_k: int = 32  # fallback k for "mean_last_k_tokens" pooling
| |
|
| |
|
class PixtralForRegression(nn.Module):
    """Wraps a LLaVA/Pixtral base model with a scalar regression head.

    The base model is run with ``output_hidden_states=True``; the last hidden
    layer is pooled into one vector per sample (strategy chosen by
    ``pooling``) and passed through a small fp32 MLP head.  The output is
    ``softplus(raw) + 1``, i.e. predictions are strictly greater than 1.
    """

    def __init__(self, base, pooling: str = "mean_image_tokens", last_k: int = 32):
        """
        Args:
            base: the multimodal base model (called with input_ids /
                attention_mask / pixel_values and optional image_sizes).
            pooling: one of "mean_image_tokens", "mean_last_k_tokens",
                "last_nonpad".
            last_k: window size used by "mean_last_k_tokens".
        """
        super().__init__()
        self.base = base
        self.pooling = pooling
        self.last_k = last_k

        # Placeholder token id for image patches; 10 is the default used here
        # when the config does not declare image_token_index.
        self.image_token_id = int(getattr(getattr(base, "config", None), "image_token_index", 10))

        # The hidden size can live in different places depending on how the
        # base model is wrapped; probe the known locations in order.
        hidden = None
        if hasattr(base, "language_model") and hasattr(base.language_model, "config"):
            hidden = getattr(base.language_model.config, "hidden_size", None)
        if hidden is None and hasattr(base, "config") and hasattr(base.config, "text_config"):
            hidden = getattr(base.config.text_config, "hidden_size", None)
        if hidden is None and hasattr(base, "config"):
            hidden = getattr(base.config, "hidden_size", None)
        if hidden is None:
            raise ValueError("Could not infer hidden_size from model config.")

        # Head kept in fp32 for numeric stability even when the base runs bf16.
        self.reg_head = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, 1024),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(1024, 1),
        ).to(dtype=torch.float32)

    def _pool(self, last_h, attention_mask, input_ids):
        """Pool the last hidden layer into one vector per sample.

        last_h: [B, L, H]
        attention_mask: [B, L] (1 for real tokens, 0 for padding)
        input_ids: [B, L] (used to locate image tokens)
        returns: [B, H]
        """
        am = attention_mask.to(last_h.device).long()
        ids = input_ids.to(last_h.device)

        if self.pooling == "mean_image_tokens":
            # Average over non-padding image-placeholder positions only.
            img_mask = (ids == self.image_token_id) & (am == 1)
            denom = img_mask.sum(dim=1, keepdim=True).clamp(min=1)  # avoid /0 when no image tokens
            pooled = (last_h * img_mask.unsqueeze(-1)).sum(dim=1) / denom
            return pooled

        if self.pooling == "mean_last_k_tokens":
            # Average the last k non-padding positions of each sample.
            bsz, L, H = last_h.shape
            out = []
            lengths = am.sum(dim=1).clamp(min=1)
            for b in range(bsz):
                end = int(lengths[b].item())
                start = max(0, end - self.last_k)
                out.append(last_h[b, start:end].mean(dim=0))
            return torch.stack(out, dim=0)

        if self.pooling == "last_nonpad":
            # Take the hidden state of the last non-padding token.
            idx = (am.sum(dim=1) - 1).clamp(min=0)
            bsz = last_h.size(0)
            return last_h[torch.arange(bsz, device=last_h.device), idx]

        raise ValueError(f"Unknown pooling: {self.pooling}")

    def forward(self, input_ids, attention_mask, pixel_values, **kwargs):
        """Run the base model and the regression head.

        Returns a dict with a single key "logits": a [B] float tensor of
        strictly-positive predictions (softplus + 1).
        Trainer-only kwargs ("labels", "num_items_in_batch") are dropped.
        """
        kwargs.pop("labels", None)
        kwargs.pop("num_items_in_batch", None)

        image_sizes = kwargs.pop("image_sizes", None)
        # Promote a single (H, W) row to a batch dimension.
        if image_sizes is not None and torch.is_tensor(image_sizes) and image_sizes.ndim == 1:
            image_sizes = image_sizes.unsqueeze(0)

        # Fallback: derive sizes from the pixel tensor itself.
        # NOTE(review): for a padded batch these are the padded H/W, not the
        # true per-image sizes — confirm callers pass image_sizes explicitly.
        if image_sizes is None and torch.is_tensor(pixel_values) and pixel_values.ndim == 4:
            B, _, H, W = pixel_values.shape
            image_sizes = torch.tensor([[H, W]] * B, dtype=torch.long, device=pixel_values.device)

        base_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True,
        )
        if image_sizes is not None:
            base_kwargs["image_sizes"] = image_sizes

        out = self.base(**base_kwargs, **kwargs)

        # hidden_states may sit on the top-level output or nested in the
        # language-model sub-output, depending on the wrapper.
        hs = getattr(out, "hidden_states", None)
        if hs is None and hasattr(out, "language_model_outputs"):
            hs = getattr(out.language_model_outputs, "hidden_states", None)
        if hs is None:
            raise ValueError("No hidden_states found in model outputs.")

        # Bugfix: removed one-shot debug instrumentation (`self._dbg` flag and
        # a bare print of per-sample image-token counts) left in the
        # production forward pass.
        pooled = self._pool(hs[-1], attention_mask, input_ids)
        raw = self.reg_head(pooled.to(torch.float32)).squeeze(-1)
        preds = F.softplus(raw) + 1.0
        return {"logits": preds}
| |
|
| |
|
class EndpointHandler:
    """HF Inference Endpoints handler: food-waste weight (grams) from an image.

    Loads a LLaVA/Pixtral base model plus a separately-saved regression head
    (``regression_head.safetensors`` + ``regression_head.json``) from ``path``.
    ``__call__`` expects ``{"inputs": {"image": <base64>, "sub_category": str,
    ...}}`` and returns an integer gram estimate (always >= 1).
    """

    def __init__(self, path: str = ".") -> None:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = InferenceConfig()
        logger.info("Initializing EndpointHandler on device: %s", self.device)
        logger.info("Model dir listing: %s", sorted(os.listdir(path)))
        logger.info("Bin shards present: %s", sorted(glob.glob(os.path.join(path, "pytorch_model-*.bin"))))

        # Optional label de-normalization, populated from regression_head.json below.
        self.y_mean: Optional[float] = None
        self.y_std: Optional[float] = None

        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)

        # Truncate from the left so the tail of the prompt (the ANSWER cue)
        # survives when the prompt exceeds max_length.
        if hasattr(self.processor, "tokenizer"):
            self.processor.tokenizer.truncation_side = "left"

        tok = self.processor.tokenizer
        if tok.pad_token_id is None and tok.eos_token_id is not None:
            tok.pad_token = tok.eos_token
            tok.pad_token_id = tok.eos_token_id

        # Images are resized by _resize_max_side; disable processor-side
        # resizing/cropping so the two do not compound.
        ip = getattr(self.processor, "image_processor", None)
        if ip is not None:
            if hasattr(ip, "do_resize"):
                ip.do_resize = False
            if hasattr(ip, "do_center_crop"):
                ip.do_center_crop = False

        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        base = LlavaForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
        ).to(self.device)

        # No generation here, so the KV cache is dead weight.
        if hasattr(base, "config"):
            base.config.use_cache = False

        self._dtype = base.dtype

        # The regression head and its config are required companion files.
        head_st = os.path.join(path, "regression_head.safetensors")
        head_js = os.path.join(path, "regression_head.json")
        if not os.path.exists(head_st) or not os.path.exists(head_js):
            raise ValueError(
                f"Missing regression head files. Expected BOTH:\n- {head_st}\n- {head_js}"
            )

        with open(head_js, "r", encoding="utf-8") as f:
            head_cfg = json.load(f)

        pooling = head_cfg.get("pooling", "mean_last_k_tokens")
        last_k = head_cfg.get("last_k", self.cfg.default_last_k)

        norm = head_cfg.get("label_norm") or {}
        if "mean" in norm and "std" in norm:
            self.y_mean = float(norm["mean"])
            self.y_std = float(norm["std"])
            if self.y_std <= 0:
                raise ValueError(f"Invalid label_norm.std={self.y_std} in regression_head.json")

        self.model = PixtralForRegression(base, pooling=pooling, last_k=last_k)
        head_sd = safe_load_file(head_st)
        self.model.reg_head.load_state_dict(head_sd, strict=True)
        self.model.to(self.device)
        self.model.eval()

        logger.info("Loaded base + regression head. pooling=%s dtype=%s", pooling, str(self._dtype))
        if self.y_mean is not None and self.y_std is not None:
            logger.info("Label normalization enabled: mean=%.6f std=%.6f", self.y_mean, self.y_std)

    @staticmethod
    def _decode_image(image_b64: str) -> Image.Image:
        """Decode a base64 payload into an RGB PIL image."""
        img_bytes = base64.b64decode(image_b64)
        return Image.open(BytesIO(img_bytes)).convert("RGB")

    @staticmethod
    def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
        """Downscale so the longest side is at most ``max_side``; never upscales."""
        w, h = img.size
        m = max(w, h)
        if m <= max_side:
            return img
        scale = max_side / m
        return img.resize((int(w * scale), int(h * scale)), resample=Image.Resampling.LANCZOS)

    @staticmethod
    def _resolve_sub_category(inputs: Dict[str, Any]) -> str:
        """Return the stripped sub-category from ``inputs``.

        Accepts either 'sub_category' or 'subcat' (bugfix: previously only
        'sub_category' was read, although the error message advertised both).

        Raises:
            ValueError: if neither key holds a non-empty value.
        """
        subcat = inputs.get("sub_category") or inputs.get("subcat")
        if not subcat or not str(subcat).strip():
            raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.")
        return str(subcat).strip()

    def _build_regression_text(self, prompt: str) -> str:
        """Render the chat template for a single user turn (image + text).

        - apply_chat_template on user message only
        - add_generation_prompt=False
        - NOTE: the "ANSWER:" cue is appended by the caller before this runs
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        chat = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=False,
            tokenize=False,
        )
        return chat

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run one prediction; returns the estimated weight in grams (int >= 1)."""
        inputs = data.get("inputs", data)
        debug = bool(inputs.get("debug", False))

        image_b64: Optional[str] = inputs.get("image")
        if not image_b64:
            raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.")

        # Raises ValueError itself when absent/empty, so no re-check needed
        # (removed a redundant unreachable `if not sub_category` branch).
        sub_category = self._resolve_sub_category(inputs)

        prompt = DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category).rstrip() + "\n\nANSWER:"

        image = self._decode_image(image_b64)
        image = self._resize_max_side(image, max_side=int(inputs.get("max_side", self.cfg.max_side)))

        max_length = int(inputs.get("max_length", self.cfg.max_length))
        text = self._build_regression_text(prompt)

        enc = self.processor(
            text=[text],
            images=[image],
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=False,
        )

        enc = {k: (v.to(self.device) if torch.is_tensor(v) else v) for k, v in enc.items()}
        enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self._dtype)

        # image_sizes rows are (height, width).
        w, h = image.size
        image_sizes = torch.tensor([[h, w]], dtype=torch.long, device=self.device)

        if debug:
            tok = self.processor.tokenizer
            last_idx = int(enc["attention_mask"][0].sum().item()) - 1
            last_id = int(enc["input_ids"][0, last_idx].item())
            logger.info("last_token_id=%d last_token='%s'", last_id, tok.decode([last_id], skip_special_tokens=False))
            logger.info("pixel_values shape: %s", tuple(enc["pixel_values"].shape))
            logger.info("image_sizes: %s", image_sizes[0].tolist())

        with torch.inference_mode():
            out = self.model(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                pixel_values=enc["pixel_values"],
                image_sizes=image_sizes,
            )
        pred_norm = float(out["logits"].item())

        # Undo label normalization when it was configured at train time.
        if self.y_mean is not None and self.y_std is not None:
            pred_g = pred_norm * float(self.y_std) + float(self.y_mean)
        else:
            pred_g = pred_norm

        pred_g = max(1.0, pred_g)  # never report less than 1 gram
        if debug:
            logger.info("pred_norm=%.6f -> pred_g=%.3f", pred_norm, pred_g)

        return int(round(pred_g))