"""Inference wrapper for nvidia/LocateAnything-3B.""" from __future__ import annotations from typing import Any import torch from PIL import Image from transformers import AutoModel, AutoProcessor, AutoTokenizer from src.config import ( DEVICE, DTYPE, GENERATION_MODE, MAX_NEW_TOKENS, MODEL_ID, TEMPERATURE, ) from src.parsing import ParseResult, parse_boxes class LocateAnythingWorker: """Stateful worker that loads LocateAnything-3B once and serves queries.""" def __init__( self, model_path: str = MODEL_ID, device: str = DEVICE, dtype_str: str = DTYPE, ) -> None: self.device = device self.dtype = getattr(torch, dtype_str, torch.bfloat16) self.model_path = model_path self._loaded = False self.tokenizer = None self.processor = None self.model = None def load(self) -> None: """Load model, tokenizer, and processor. Call once at startup.""" if self._loaded: return self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True) self.model = ( AutoModel.from_pretrained( self.model_path, torch_dtype=self.dtype, trust_remote_code=True, ) .to(self.device) .eval() ) self._loaded = True @torch.no_grad() def predict( self, image: Image.Image, question: str, generation_mode: str = GENERATION_MODE, max_new_tokens: int = MAX_NEW_TOKENS, temperature: float = TEMPERATURE, ) -> dict[str, Any]: """Run inference on an image with a text prompt. Returns dict with 'answer', optionally 'history' and 'stats'. """ if not self._loaded: self.load() messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": question}, ], } ] text = self.processor.py_apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) images, videos = self.processor.process_vision_info(messages) inputs = self.processor( text=[text], images=images, videos=videos, return_tensors="pt" ).to(self.device) pixel_values = inputs["pixel_values"].to(self.dtype) input_ids = inputs["input_ids"] image_grid_hws = inputs.get("image_grid_hws", None) response = self.model.generate( pixel_values=pixel_values, input_ids=input_ids, attention_mask=inputs["attention_mask"], image_grid_hws=image_grid_hws, tokenizer=self.tokenizer, max_new_tokens=max_new_tokens, use_cache=True, generation_mode=generation_mode, temperature=temperature, do_sample=True, top_p=0.9, repetition_penalty=1.1, verbose=False, ) result: dict[str, Any] = {"answer": response[0] if isinstance(response, tuple) else response} if isinstance(response, tuple) and len(response) >= 3: result["history"] = response[1] result["stats"] = response[2] return result def detect(self, image: Image.Image, categories: list[str], **kwargs: Any) -> dict[str, Any]: """Object detection with multiple categories.""" cats = "".join(categories) prompt = f"Locate all the instances that matches the following description: {cats}." return self.predict(image, prompt, **kwargs) def ground_single(self, image: Image.Image, phrase: str, **kwargs: Any) -> dict[str, Any]: """Phrase grounding — single instance.""" prompt = f"Locate a single instance that matches the following description: {phrase}." return self.predict(image, prompt, **kwargs) def ground_multi(self, image: Image.Image, phrase: str, **kwargs: Any) -> dict[str, Any]: """Phrase grounding — multiple instances.""" prompt = f"Locate all the instances that match the following description: {phrase}." return self.predict(image, prompt, **kwargs) def run_localization( image: Image.Image, prompt: str, worker: LocateAnythingWorker | None = None, ) -> tuple[Image.Image, str, ParseResult]: """High-level entry point: run localization and return annotated image + results. Args: image: Input PIL image. prompt: Natural language prompt. worker: Pre-loaded worker instance. If None, creates and loads one. Returns: Tuple of (annotated_image, raw_output, parse_result). """ from src.visualization import create_no_detection_overlay, draw_boxes if worker is None: worker = LocateAnythingWorker() worker.load() result = worker.predict(image, prompt) raw_output = result.get("answer", "") img_w, img_h = image.size parsed = parse_boxes(raw_output, img_w, img_h) if parsed.boxes: annotated = draw_boxes(image, parsed.boxes) else: annotated = create_no_detection_overlay(image) return annotated, raw_output, parsed