Spaces:
Running
Running
| """Inference wrapper for nvidia/LocateAnything-3B.""" | |
| from __future__ import annotations | |
| from typing import Any | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoModel, AutoProcessor, AutoTokenizer | |
| from src.config import ( | |
| DEVICE, | |
| DTYPE, | |
| GENERATION_MODE, | |
| MAX_NEW_TOKENS, | |
| MODEL_ID, | |
| TEMPERATURE, | |
| ) | |
| from src.parsing import ParseResult, parse_boxes | |
class LocateAnythingWorker:
    """Stateful worker that loads LocateAnything-3B once and serves queries.

    Construction is cheap: the tokenizer, processor, and model are only
    created by :meth:`load`, which :meth:`predict` calls on demand. Once
    loaded, the same objects are reused for every query.
    """

    def __init__(
        self,
        model_path: str = MODEL_ID,
        device: str = DEVICE,
        dtype_str: str = DTYPE,
    ) -> None:
        """Store configuration; no weights are loaded here.

        Args:
            model_path: Identifier or path passed to ``from_pretrained``.
            device: Device the model and the processed inputs are moved to.
            dtype_str: Name of a ``torch`` dtype attribute (e.g. ``"bfloat16"``).
                A name that is missing from ``torch`` or that does not refer
                to a ``torch.dtype`` falls back to ``torch.bfloat16``.
        """
        self.device = device
        # getattr alone would accept ANY attribute of the torch module
        # (e.g. "nn" or "load"); only a real dtype is usable further down as
        # ``torch_dtype`` and as the target of ``Tensor.to``.
        resolved = getattr(torch, dtype_str, torch.bfloat16)
        self.dtype = resolved if isinstance(resolved, torch.dtype) else torch.bfloat16
        self.model_path = model_path
        self._loaded = False
        self.tokenizer: Any = None
        self.processor: Any = None
        self.model: Any = None

    def load(self) -> None:
        """Load model, tokenizer, and processor. Call once at startup.

        Subsequent calls are no-ops once loading has completed.
        """
        if self._loaded:
            return
        # NOTE: trust_remote_code=True executes Python shipped with the model
        # repository; only point model_path at sources you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = (
            AutoModel.from_pretrained(
                self.model_path,
                torch_dtype=self.dtype,
                trust_remote_code=True,
            )
            .to(self.device)
            .eval()
        )
        # Flag is set last so a failed load is retried on the next call.
        self._loaded = True

    def predict(
        self,
        image: Image.Image,
        question: str,
        generation_mode: str = GENERATION_MODE,
        max_new_tokens: int = MAX_NEW_TOKENS,
        temperature: float = TEMPERATURE,
    ) -> dict[str, Any]:
        """Run inference on an image with a text prompt.

        Args:
            image: Input PIL image, placed in the user message.
            question: Text prompt placed after the image in the user message.
            generation_mode: Forwarded unchanged to the model's ``generate``.
            max_new_tokens: Forwarded unchanged to the model's ``generate``.
            temperature: Forwarded unchanged to the model's ``generate``.

        Returns:
            Dict with key ``'answer'``. When ``generate`` returns a tuple,
            ``'answer'`` is its first element; if that tuple has at least
            three elements, ``'history'`` and ``'stats'`` hold the second
            and third. Otherwise ``'answer'`` is the raw return value.
        """
        if not self._loaded:
            self.load()

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": question},
                ],
            }
        ]
        text = self.processor.py_apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        images, videos = self.processor.process_vision_info(messages)
        inputs = self.processor(
            text=[text], images=images, videos=videos, return_tensors="pt"
        ).to(self.device)

        # Pixel data is cast to the model's dtype; token ids stay as produced.
        pixel_values = inputs["pixel_values"].to(self.dtype)
        input_ids = inputs["input_ids"]
        image_grid_hws = inputs.get("image_grid_hws", None)

        # Pure inference: make sure autograd is off even if the remote-code
        # ``generate`` does not disable it itself.
        with torch.no_grad():
            response = self.model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=inputs["attention_mask"],
                image_grid_hws=image_grid_hws,
                tokenizer=self.tokenizer,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                generation_mode=generation_mode,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                verbose=False,
            )

        if not isinstance(response, tuple):
            return {"answer": response}

        result: dict[str, Any] = {"answer": response[0]}
        if len(response) >= 3:
            result["history"] = response[1]
            result["stats"] = response[2]
        return result

    def detect(self, image: Image.Image, categories: list[str], **kwargs: Any) -> dict[str, Any]:
        """Object detection with multiple categories.

        Args:
            image: Input PIL image.
            categories: Category names; joined with ``"</c>"`` into one prompt.
            **kwargs: Forwarded to :meth:`predict`.

        Returns:
            The dict returned by :meth:`predict`.
        """
        cats = "</c>".join(categories)
        # NOTE(review): prompt wording ("matches") is kept verbatim; presumably
        # it mirrors the format the model expects — confirm before editing.
        prompt = f"Locate all the instances that matches the following description: {cats}."
        return self.predict(image, prompt, **kwargs)

    def ground_single(self, image: Image.Image, phrase: str, **kwargs: Any) -> dict[str, Any]:
        """Phrase grounding — single instance.

        Args:
            image: Input PIL image.
            phrase: Description inserted into the prompt.
            **kwargs: Forwarded to :meth:`predict`.

        Returns:
            The dict returned by :meth:`predict`.
        """
        prompt = f"Locate a single instance that matches the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)

    def ground_multi(self, image: Image.Image, phrase: str, **kwargs: Any) -> dict[str, Any]:
        """Phrase grounding — multiple instances.

        Args:
            image: Input PIL image.
            phrase: Description inserted into the prompt.
            **kwargs: Forwarded to :meth:`predict`.

        Returns:
            The dict returned by :meth:`predict`.
        """
        prompt = f"Locate all the instances that match the following description: {phrase}."
        return self.predict(image, prompt, **kwargs)
def run_localization(
    image: Image.Image,
    prompt: str,
    worker: LocateAnythingWorker | None = None,
) -> tuple[Image.Image, str, ParseResult]:
    """Localize objects described by ``prompt`` and render the outcome.

    Args:
        image: Input PIL image.
        prompt: Natural language prompt.
        worker: Pre-loaded worker instance. When ``None``, a new
            ``LocateAnythingWorker`` is constructed and loaded for this call.

    Returns:
        Tuple of (annotated_image, raw_output, parse_result). The annotated
        image has the parsed boxes drawn on it, or is the "no detection"
        overlay when parsing yields no boxes.
    """
    # Function-scope import: within this module the visualization helpers
    # are only needed here.
    from src.visualization import create_no_detection_overlay, draw_boxes

    if worker is None:
        worker = LocateAnythingWorker()
        worker.load()

    answer_text = worker.predict(image, prompt).get("answer", "")

    # The parser receives the image dimensions alongside the raw model text.
    width, height = image.size
    detections = parse_boxes(answer_text, width, height)

    if not detections.boxes:
        return create_no_detection_overlay(image), answer_text, detections
    return draw_boxes(image, detections.boxes), answer_text, detections