obj_localizer / src /inference.py
3v324v23's picture
fix: resolve all ruff lint errors
cf388f7
Raw
History Blame Contribute Delete
5.43 kB
"""Inference wrapper for nvidia/LocateAnything-3B."""
from __future__ import annotations
from typing import Any
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from src.config import (
DEVICE,
DTYPE,
GENERATION_MODE,
MAX_NEW_TOKENS,
MODEL_ID,
TEMPERATURE,
)
from src.parsing import ParseResult, parse_boxes
class LocateAnythingWorker:
    """Stateful worker that loads LocateAnything-3B once and serves queries.

    The tokenizer, processor and model are loaded on the first call to
    :meth:`load` (called implicitly by :meth:`predict`) and then reused for
    every subsequent request.
    """

    def __init__(
        self,
        model_path: str = MODEL_ID,
        device: str = DEVICE,
        dtype_str: str = DTYPE,
    ) -> None:
        """Configure the worker without loading any weights.

        Args:
            model_path: Repo id or local path passed to every ``from_pretrained`` call.
            device: Device the model and the processed inputs are moved to.
            dtype_str: Name of a ``torch`` dtype (``"bfloat16"`` or
                ``"torch.bfloat16"``); a ``torch.dtype`` instance is also
                accepted. Anything that does not name a real dtype falls back
                to ``torch.bfloat16``.
        """
        self.device = device
        self.dtype = self._resolve_dtype(dtype_str)
        self.model_path = model_path
        self._loaded = False
        self.tokenizer = None
        self.processor = None
        self.model = None

    @staticmethod
    def _resolve_dtype(dtype_str: str | torch.dtype) -> torch.dtype:
        """Map a dtype name (or a dtype) to a ``torch.dtype``, defaulting to bfloat16."""
        if isinstance(dtype_str, torch.dtype):
            return dtype_str
        name = str(dtype_str).strip()
        if name.startswith("torch."):
            name = name[len("torch."):]
        # A bare getattr could return a function, module or class (e.g. "zeros",
        # "cuda", "device") that would only blow up later inside from_pretrained;
        # accept genuine dtypes only and otherwise keep the bfloat16 fallback.
        candidate = getattr(torch, name, None)
        return candidate if isinstance(candidate, torch.dtype) else torch.bfloat16

    def load(self) -> None:
        """Load tokenizer, processor, and model. Call once at startup.

        The model is moved to ``self.device`` and switched to eval mode.
        Repeated calls after a successful load are no-ops.
        """
        if self._loaded:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = (
            AutoModel.from_pretrained(
                self.model_path,
                torch_dtype=self.dtype,
                trust_remote_code=True,
            )
            .to(self.device)
            .eval()
        )
        # Only mark as loaded once every component succeeded, so a failed
        # load is retried on the next call.
        self._loaded = True

    @staticmethod
    def _format_response(response: Any) -> dict[str, Any]:
        """Normalise the raw ``generate`` output into the public result dict.

        A tuple response contributes its first element as ``'answer'``; when it
        has at least three elements, elements 1 and 2 become ``'history'`` and
        ``'stats'``. Any other response is used as the answer directly.
        """
        if not isinstance(response, tuple):
            return {"answer": response}
        result: dict[str, Any] = {"answer": response[0]}
        if len(response) >= 3:
            result["history"] = response[1]
            result["stats"] = response[2]
        return result

    @torch.no_grad()
    def predict(
        self,
        image: Image.Image,
        question: str,
        generation_mode: str = GENERATION_MODE,
        max_new_tokens: int = MAX_NEW_TOKENS,
        temperature: float = TEMPERATURE,
    ) -> dict[str, Any]:
        """Run inference on an image with a text prompt.

        Args:
            image: Input PIL image.
            question: Text prompt sent alongside the image.
            generation_mode: Forwarded verbatim to the model's custom ``generate``.
            max_new_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature. Values ``<= 0`` switch to greedy
                decoding (``do_sample=False``) instead of sampling.

        Returns:
            Dict with ``'answer'`` and, when ``generate`` returns a tuple of at
            least three elements, ``'history'`` and ``'stats'``.
        """
        if not self._loaded:
            self.load()
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]
        text = self.processor.py_apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(
            text=[text], images=images, videos=videos, return_tensors="pt"
        ).to(self.device)
        # Pixel values must match the model's weight dtype.
        pixel_values = inputs["pixel_values"].to(self.dtype)
        # Sampling needs a strictly positive temperature; fall back to greedy
        # decoding otherwise. None keeps the original sampling behaviour.
        do_sample = temperature is None or temperature > 0
        response = self.model.generate(
            pixel_values=pixel_values,
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            image_grid_hws=inputs.get("image_grid_hws", None),
            tokenizer=self.tokenizer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            generation_mode=generation_mode,
            temperature=temperature,
            do_sample=do_sample,
            top_p=0.9,
            repetition_penalty=1.1,
            verbose=False,
        )
        return self._format_response(response)

    def detect(self, image: Image.Image, categories: list[str], **kwargs: Any) -> dict[str, Any]:
        """Object detection with multiple categories.

        Categories are joined with the ``</c>`` separator into a single prompt;
        extra keyword arguments are forwarded to :meth:`predict`.
        """
        cats = "</c>".join(categories)
        # NOTE(review): "matches" differs from ground_multi's "match"; left as-is
        # in case it mirrors the prompt format the model expects -- confirm.
        prompt = f"Locate all the instances that matches the following description: {cats}."
        return self.predict(image, prompt, **kwargs)

    def ground_single(self, image: Image.Image, phrase: str, **kwargs: Any) -> dict[str, Any]:
        """Phrase grounding — single instance; kwargs are forwarded to :meth:`predict`."""
        prompt = f"Locate a single instance that matches the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)

    def ground_multi(self, image: Image.Image, phrase: str, **kwargs: Any) -> dict[str, Any]:
        """Phrase grounding — multiple instances; kwargs are forwarded to :meth:`predict`."""
        prompt = f"Locate all the instances that match the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)
def run_localization(
    image: Image.Image,
    prompt: str,
    worker: LocateAnythingWorker | None = None,
) -> tuple[Image.Image, str, ParseResult]:
    """Localize what ``prompt`` describes in ``image`` and render the outcome.

    Args:
        image: Input PIL image.
        prompt: Natural language prompt.
        worker: Pre-loaded worker instance. When omitted, a fresh
            ``LocateAnythingWorker`` is constructed and loaded for this call.

    Returns:
        ``(annotated_image, raw_output, parse_result)``. The annotated image
        shows the parsed boxes, or a "no detection" overlay when none parsed.
    """
    # Function-local import, kept exactly where the original had it.
    from src.visualization import create_no_detection_overlay, draw_boxes

    active_worker = worker
    if active_worker is None:
        active_worker = LocateAnythingWorker()
        active_worker.load()

    raw_output = active_worker.predict(image, prompt).get("answer", "")

    # Box parsing needs the image dimensions to interpret coordinates.
    width, height = image.size
    parsed = parse_boxes(raw_output, width, height)

    annotated = (
        draw_boxes(image, parsed.boxes)
        if parsed.boxes
        else create_no_detection_overlay(image)
    )
    return annotated, raw_output, parsed