from __future__ import annotations import base64 import os from dataclasses import dataclass from io import BytesIO from typing import Any, Dict, Optional import glob import json import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from transformers import AutoProcessor, LlavaForConditionalGeneration from transformers.utils import logging from safetensors.torch import load_file as safe_load_file logger = logging.get_logger(__name__) logging.set_verbosity_info() DEFAULT_PROMPT = """ Here is a photo of some food waste. Your task is to provide a weight estimate (in grams) for the following food waste: {SUB_CATEGORY} While providing your weight estimate, please keep in mind the following: *** there might be other food items on the photo, I want you to provide the weight estimate for {SUB_CATEGORY} only! *** make sure to provide your best guess, never say 0! OUTPUT (digits only, no extra text): WEIGHT """.strip() @dataclass class InferenceConfig: max_length: int = 2048 max_side: int = 512 default_last_k: int = 32 class PixtralForRegression(nn.Module): def __init__(self, base, pooling: str = "mean_image_tokens", last_k: int = 32): super().__init__() self.base = base self.pooling = pooling self.last_k = last_k self.image_token_id = int(getattr(getattr(base, "config", None), "image_token_index", 10)) hidden = None if hasattr(base, "language_model") and hasattr(base.language_model, "config"): hidden = getattr(base.language_model.config, "hidden_size", None) if hidden is None and hasattr(base, "config") and hasattr(base.config, "text_config"): hidden = getattr(base.config.text_config, "hidden_size", None) if hidden is None and hasattr(base, "config"): hidden = getattr(base.config, "hidden_size", None) if hidden is None: raise ValueError("Could not infer hidden_size from model config.") self.reg_head = nn.Sequential( nn.LayerNorm(hidden), nn.Linear(hidden, 1024), nn.GELU(), nn.Dropout(0.1), nn.Linear(1024, 1), ).to(dtype=torch.float32) def _pool(self, last_h, attention_mask, input_ids): """ last_h: [B, L, H] attention_mask: [B, L] (1 for real tokens, 0 for padding) returns: [B, H] """ am = attention_mask.to(last_h.device).long() ids = input_ids.to(last_h.device) if self.pooling == "mean_image_tokens": # Mean over all [IMG] placeholder tokens img_mask = (ids == self.image_token_id) & (am == 1) # [B,L] denom = img_mask.sum(dim=1, keepdim=True).clamp(min=1) # [B,1] pooled = (last_h * img_mask.unsqueeze(-1)).sum(dim=1) / denom return pooled # [B,H] if self.pooling == "mean_last_k_tokens": # Mean over last K *non-pad* tokens bsz, L, H = last_h.shape out = [] lengths = am.sum(dim=1).clamp(min=1) # [B] for b in range(bsz): end = int(lengths[b].item()) start = max(0, end - self.last_k) out.append(last_h[b, start:end].mean(dim=0)) return torch.stack(out, dim=0) # [B,H] if self.pooling == "last_nonpad": idx = (am.sum(dim=1) - 1).clamp(min=0) # [B] bsz = last_h.size(0) return last_h[torch.arange(bsz, device=last_h.device), idx] # [B,H] raise ValueError(f"Unknown pooling: {self.pooling}") def forward(self, input_ids, attention_mask, pixel_values, **kwargs): kwargs.pop("labels", None) kwargs.pop("num_items_in_batch", None) image_sizes = kwargs.pop("image_sizes", None) if image_sizes is not None and torch.is_tensor(image_sizes) and image_sizes.ndim == 1: image_sizes = image_sizes.unsqueeze(0) if image_sizes is None and torch.is_tensor(pixel_values) and pixel_values.ndim == 4: B, _, H, W = pixel_values.shape image_sizes = torch.tensor([[H, W]] * B, dtype=torch.long, device=pixel_values.device) base_kwargs = dict( input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, output_hidden_states=True, return_dict=True, ) if image_sizes is not None: base_kwargs["image_sizes"] = image_sizes out = self.base(**base_kwargs, **kwargs) hs = getattr(out, "hidden_states", None) if hs is None and hasattr(out, "language_model_outputs"): hs = getattr(out.language_model_outputs, "hidden_states", None) if hs is None: raise ValueError("No hidden_states found in model outputs.") last_h = hs[-1] if not hasattr(self, "_dbg"): self._dbg = True img_mask = (input_ids == self.image_token_id) & (attention_mask == 1) print("IMG tokens per sample:", img_mask.sum(dim=1)[:4].tolist()) pooled = self._pool(last_h, attention_mask, input_ids) raw = self.reg_head(pooled.to(torch.float32)).squeeze(-1) preds = F.softplus(raw) + 1.0 return {"logits": preds} class EndpointHandler: def __init__(self, path: str = ".") -> None: self.device = "cuda" if torch.cuda.is_available() else "cpu" self.cfg = InferenceConfig() logger.info("Initializing EndpointHandler on device: %s", self.device) logger.info("Model dir listing: %s", sorted(os.listdir(path))) logger.info("Bin shards present: %s", sorted(glob.glob(os.path.join(path, "pytorch_model-*.bin")))) self.y_mean: Optional[float] = None self.y_std: Optional[float] = None # --- processor --- self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True) if hasattr(self.processor, "tokenizer"): self.processor.tokenizer.truncation_side = "left" tok = self.processor.tokenizer if tok.pad_token_id is None and tok.eos_token_id is not None: tok.pad_token = tok.eos_token tok.pad_token_id = tok.eos_token_id ip = getattr(self.processor, "image_processor", None) if ip is not None: if hasattr(ip, "do_resize"): ip.do_resize = False if hasattr(ip, "do_center_crop"): ip.do_center_crop = False dtype = torch.bfloat16 if self.device == "cuda" else torch.float32 base = LlavaForConditionalGeneration.from_pretrained( path, torch_dtype=dtype, low_cpu_mem_usage=True, trust_remote_code=True, ).to(self.device) if hasattr(base, "config"): base.config.use_cache = False self._dtype = base.dtype # --- load regression head (safetensors json) --- head_st = os.path.join(path, "regression_head.safetensors") head_js = os.path.join(path, "regression_head.json") if not os.path.exists(head_st) or not os.path.exists(head_js): raise ValueError( f"Missing regression head files. Expected BOTH:\n- {head_st}\n- {head_js}" ) with open(head_js, "r") as f: head_cfg = json.load(f) pooling = head_cfg.get("pooling", "mean_last_k_tokens") last_k = head_cfg.get("last_k", self.cfg.default_last_k) norm = head_cfg.get("label_norm") or {} if "mean" in norm and "std" in norm: self.y_mean = float(norm["mean"]) self.y_std = float(norm["std"]) if self.y_std <= 0: raise ValueError(f"Invalid label_norm.std={self.y_std} in regression_head.json") self.model = PixtralForRegression(base, pooling=pooling, last_k=last_k) head_sd = safe_load_file(head_st) self.model.reg_head.load_state_dict(head_sd, strict=True) self.model.to(self.device) self.model.eval() logger.info("Loaded base + regression head. pooling=%s dtype=%s", pooling, str(self._dtype)) if self.y_mean is not None and self.y_std is not None: logger.info("Label normalization enabled: mean=%.6f std=%.6f", self.y_mean, self.y_std) @staticmethod def _decode_image(image_b64: str) -> Image.Image: img_bytes = base64.b64decode(image_b64) return Image.open(BytesIO(img_bytes)).convert("RGB") @staticmethod def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image: w, h = img.size m = max(w, h) if m <= max_side: return img scale = max_side / m return img.resize((int(w * scale), int(h * scale)), resample=Image.Resampling.LANCZOS) @staticmethod def _resolve_sub_category(inputs: Dict[str, Any]) -> str: subcat = inputs.get("sub_category") if not subcat or not str(subcat).strip(): raise ValueError("Missing 'subcat' (or 'sub_category') in 'inputs'.") return str(subcat).strip() def _build_regression_text(self, prompt: str) -> str: """ - apply_chat_template on user message only - add_generation_prompt=False - append "\n\nANSWER:" """ messages = [ { "role": "user", "content": [ {"type": "image"}, {"type": "text", "text": prompt}, ], } ] chat = self.processor.apply_chat_template( messages, add_generation_prompt=False, tokenize=False, ) return chat def __call__(self, data: Dict[str, Any]) -> Any: inputs = data.get("inputs", data) debug = bool(inputs.get("debug", False)) image_b64: Optional[str] = inputs.get("image") if not image_b64: raise ValueError("Missing 'image' field (base64-encoded) in 'inputs'.") sub_category = self._resolve_sub_category(inputs) if not sub_category: raise ValueError("Missing 'sub_category' (or 'subcat') in 'inputs'.") prompt = DEFAULT_PROMPT.format(SUB_CATEGORY=sub_category).rstrip() + "\n\nANSWER:" image = self._decode_image(image_b64) image = self._resize_max_side(image, max_side=int(inputs.get("max_side", self.cfg.max_side))) max_length = int(inputs.get("max_length", self.cfg.max_length)) text = self._build_regression_text(prompt) enc = self.processor( text=[text], images=[image], return_tensors="pt", truncation=True, max_length=max_length, padding=False, ) enc = {k: (v.to(self.device) if torch.is_tensor(v) else v) for k, v in enc.items()} enc["pixel_values"] = enc["pixel_values"].to(self.device, dtype=self._dtype) w, h = image.size image_sizes = torch.tensor([[h, w]], dtype=torch.long, device=self.device) if debug: tok = self.processor.tokenizer last_idx = int(enc["attention_mask"][0].sum().item()) - 1 last_id = int(enc["input_ids"][0, last_idx].item()) logger.info("last_token_id=%d last_token='%s'", last_id, tok.decode([last_id], skip_special_tokens=False)) logger.info("pixel_values shape: %s", tuple(enc["pixel_values"].shape)) logger.info("image_sizes: %s", image_sizes[0].tolist()) with torch.inference_mode(): out = self.model( input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], pixel_values=enc["pixel_values"], image_sizes=image_sizes, ) pred_norm = float(out["logits"].item()) # If trained with mean/std normalization, unnormalize here if self.y_mean is not None and self.y_std is not None: pred_g = pred_norm * float(self.y_std) + float(self.y_mean) else: pred_g = pred_norm pred_g = max(1.0, pred_g) if debug: logger.info("pred_norm=%.6f -> pred_g=%.3f", pred_norm, pred_g) return int(round(pred_g))