Spaces:
Running
Running
File size: 5,433 Bytes
23db765 cf388f7 23db765 cf388f7 23db765 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 | """Inference wrapper for nvidia/LocateAnything-3B."""
from __future__ import annotations
from typing import Any
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from src.config import (
DEVICE,
DTYPE,
GENERATION_MODE,
MAX_NEW_TOKENS,
MODEL_ID,
TEMPERATURE,
)
from src.parsing import ParseResult, parse_boxes
class LocateAnythingWorker:
    """Stateful worker that loads LocateAnything-3B once and serves queries.

    The tokenizer, processor, and model are created by :meth:`load`, which is
    called explicitly or automatically on the first :meth:`predict`.
    """

    def __init__(
        self,
        model_path: str = MODEL_ID,
        device: str = DEVICE,
        dtype_str: str = DTYPE,
    ) -> None:
        """Store configuration; nothing is loaded until :meth:`load` runs.

        Args:
            model_path: Identifier or path handed to ``from_pretrained``.
            device: Device the model and the processed inputs are moved to.
            dtype_str: Name of a ``torch`` dtype attribute. Names that do not
                resolve to a ``torch.dtype`` fall back to ``torch.bfloat16``.
        """
        self.device = device
        self.dtype = self._resolve_dtype(dtype_str)
        self.model_path = model_path
        self._loaded = False
        self.tokenizer = None
        self.processor = None
        self.model = None

    @staticmethod
    def _resolve_dtype(dtype_str: str) -> torch.dtype:
        """Map a dtype name to a ``torch.dtype``, defaulting to bfloat16.

        A plain ``getattr(torch, name, default)`` only falls back when the
        attribute is missing, so a name such as ``"nn"`` or ``"cuda"`` would
        yield a module instead of a dtype. The isinstance check extends the
        same fallback to those cases.
        """
        candidate = getattr(torch, dtype_str, None)
        if isinstance(candidate, torch.dtype):
            return candidate
        return torch.bfloat16

    def load(self) -> None:
        """Load model, tokenizer, and processor. Call once at startup.

        Subsequent calls are no-ops once a previous call has completed.
        """
        if self._loaded:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = (
            AutoModel.from_pretrained(
                self.model_path,
                torch_dtype=self.dtype,
                trust_remote_code=True,
            )
            .to(self.device)
            .eval()
        )
        # Set only after every component loaded, so a failed load is retried
        # on the next call instead of being treated as complete.
        self._loaded = True

    @torch.no_grad()
    def predict(
        self,
        image: Image.Image,
        question: str,
        generation_mode: str = GENERATION_MODE,
        max_new_tokens: int = MAX_NEW_TOKENS,
        temperature: float = TEMPERATURE,
    ) -> dict[str, Any]:
        """Run inference on an image with a text prompt.

        Loads the model first if :meth:`load` has not been called yet.

        Args:
            image: Input PIL image, embedded in the chat message.
            question: Text prompt sent alongside the image.
            generation_mode: Forwarded to ``model.generate``.
            max_new_tokens: Forwarded to ``model.generate``.
            temperature: Forwarded to ``model.generate``.

        Returns:
            Dict with ``'answer'`` (the first element when ``generate`` returns
            a tuple, otherwise the return value itself). ``'history'`` and
            ``'stats'`` are added when ``generate`` returns a tuple of at
            least three elements.
        """
        if not self._loaded:
            self.load()

        # Single-turn conversation: one user message holding image + text.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]
        text = self.processor.py_apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(
            text=[text], images=images, videos=videos, return_tensors="pt"
        ).to(self.device)

        # Cast pixel data to the dtype that was passed as torch_dtype on load.
        pixel_values = inputs["pixel_values"].to(self.dtype)
        input_ids = inputs["input_ids"]
        # Optional key: passed as None when the processor does not supply it.
        image_grid_hws = inputs.get("image_grid_hws", None)

        response = self.model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            image_grid_hws=image_grid_hws,
            tokenizer=self.tokenizer,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            generation_mode=generation_mode,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            verbose=False,
        )

        result: dict[str, Any] = {"answer": response[0] if isinstance(response, tuple) else response}
        if isinstance(response, tuple) and len(response) >= 3:
            result["history"] = response[1]
            result["stats"] = response[2]
        return result

    def detect(self, image: Image.Image, categories: list[str], **kwargs: Any) -> dict[str, Any]:
        """Object detection with multiple categories.

        Args:
            image: Input PIL image.
            categories: Category names; joined with ``"</c>"`` in the prompt.
            **kwargs: Forwarded to :meth:`predict`.

        Returns:
            The dict returned by :meth:`predict`.
        """
        cats = "</c>".join(categories)
        # NOTE(review): "matches" here differs from "match" in ground_multi;
        # left verbatim because the wording is part of the model prompt.
        # Confirm against the model's documented prompt before changing.
        prompt = f"Locate all the instances that matches the following description: {cats}."
        return self.predict(image, prompt, **kwargs)

    def ground_single(self, image: Image.Image, phrase: str, **kwargs: Any) -> dict[str, Any]:
        """Phrase grounding — single instance.

        Args:
            image: Input PIL image.
            phrase: Description inserted into the prompt.
            **kwargs: Forwarded to :meth:`predict`.

        Returns:
            The dict returned by :meth:`predict`.
        """
        prompt = f"Locate a single instance that matches the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)

    def ground_multi(self, image: Image.Image, phrase: str, **kwargs: Any) -> dict[str, Any]:
        """Phrase grounding — multiple instances.

        Args:
            image: Input PIL image.
            phrase: Description inserted into the prompt.
            **kwargs: Forwarded to :meth:`predict`.

        Returns:
            The dict returned by :meth:`predict`.
        """
        prompt = f"Locate all the instances that match the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)
def run_localization(
    image: Image.Image,
    prompt: str,
    worker: LocateAnythingWorker | None = None,
) -> tuple[Image.Image, str, ParseResult]:
    """Localize objects for a prompt and render the outcome onto the image.

    Args:
        image: Input PIL image.
        prompt: Natural language prompt.
        worker: Pre-loaded worker instance. If None, a new worker is created
            and loaded for this call.

    Returns:
        Tuple of (annotated_image, raw_output, parse_result). The annotated
        image comes from ``draw_boxes`` when the parse result has boxes,
        otherwise from ``create_no_detection_overlay``.
    """
    # Imported at call time rather than at module import.
    from src.visualization import create_no_detection_overlay, draw_boxes

    active_worker = worker
    if active_worker is None:
        active_worker = LocateAnythingWorker()
        active_worker.load()

    answer_text = active_worker.predict(image, prompt).get("answer", "")

    # The parser receives the image dimensions alongside the raw model text.
    width, height = image.size
    detections = parse_boxes(answer_text, width, height)

    if not detections.boxes:
        return create_no_detection_overlay(image), answer_text, detections
    return draw_boxes(image, detections.boxes), answer_text, detections
|