| """MiniCPM-V 4.6 wrapper. Loads the model and runs inference on film scans.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import os |
| import re |
| from pathlib import Path |
| from typing import Any |
|
|
logger = logging.getLogger(__name__)

# Repository root, taken as the third ancestor directory of this file.
REPO_ROOT = Path(__file__).resolve().parents[2]

# Candidate model sources: a locally merged checkpoint directory, the
# fine-tuned Hugging Face Hub repo, and the upstream base model id.
LOCAL_MODEL_PATH = REPO_ROOT.joinpath("checkpoints", "minicpm-v-4.6-merged")
HF_MODEL_ID = "Lonelyguyse1/halide-vision"
BASE_MODEL_ID = "openbmb/MiniCPM-V-4_6"
|
|
| DOWNSAMPLE_MODE = os.getenv("HALIDE_DOWNSAMPLE_MODE", "4x") |
| MAX_SLICE_NUMS = int(os.getenv("HALIDE_MAX_SLICE_NUMS", "36")) |
| MAX_NEW_TOKENS = int(os.getenv("HALIDE_MAX_NEW_TOKENS", "3072")) |
|
|
# Instruction sent with every image; it asks the model for a JSON object
# holding a 'defects' array of labelled, normalized bounding boxes.
DETECTION_PROMPT = (
    "You are a film defect detection engine. "
    "Analyze the film scan and detect all visible defects. "
    "Output a JSON object with a 'defects' array. "
    "Each defect has: 'label' (dust, dirt, scratch, long_hair, short_hair), "
    "'bbox' (normalized [x_min, y_min, x_max, y_max] from 0.0 to 1.0). "
    "Output JSON only, no explanation."
)
|
|
|
|
def _resolve_model_path() -> str:
    """Choose which model location to load from.

    Preference order:

    1. The local merged checkpoint, when both its directory and its
       ``config.json`` exist.
    2. The Hugging Face Hub repo, when ``HF_TOKEN`` is set to a non-empty
       value.
    3. The base model id.

    Returns:
        The selected path or model id as a string.
    """
    local_dir = LOCAL_MODEL_PATH
    has_local_checkpoint = (
        local_dir.exists() and (local_dir / "config.json").exists()
    )
    if has_local_checkpoint:
        logger.info("Using local merged model at %s", local_dir)
        return str(local_dir)

    if not os.getenv("HF_TOKEN"):
        logger.info("Falling back to base model %s", BASE_MODEL_ID)
        return BASE_MODEL_ID

    logger.info("Using HF Hub repo %s", HF_MODEL_ID)
    return HF_MODEL_ID
|
|
|
|
| class MiniCPMVDetector: |
| """Lazy-loading wrapper around MiniCPM-V 4.6 for film defect detection.""" |
|
|
| def __init__(self, model_path: str | None = None) -> None: |
| self._model_path = model_path or _resolve_model_path() |
| self._model: Any = None |
| self._processor: Any = None |
| self._dtype: Any = None |
| self._device: str = "cpu" |
|
|
| @property |
| def model_path(self) -> str: |
| return self._model_path |
|
|
| def load(self) -> None: |
| if self._model is not None: |
| return |
| import torch |
| from transformers import AutoModelForImageTextToText, AutoProcessor |
|
|
| logger.info("Loading MiniCPM-V 4.6 from %s", self._model_path) |
| self._processor = AutoProcessor.from_pretrained( |
| self._model_path, trust_remote_code=True |
| ) |
| self._dtype = torch.bfloat16 |
| self._model = AutoModelForImageTextToText.from_pretrained( |
| self._model_path, |
| torch_dtype=self._dtype, |
| device_map="auto", |
| trust_remote_code=True, |
| ) |
| self._device = str(next(self._model.parameters()).device) |
| logger.info("Model loaded on %s", self._device) |
|
|
| def detect(self, image: Any) -> dict: |
| """Run defect detection on a PIL image. Returns parsed JSON dict.""" |
| import torch |
|
|
| if self._model is None: |
| self.load() |
|
|
| messages = [ |
| { |
| "role": "user", |
| "content": [ |
| {"type": "image", "image": image}, |
| {"type": "text", "text": DETECTION_PROMPT}, |
| ], |
| } |
| ] |
|
|
| inputs = self._processor.apply_chat_template( |
| messages, |
| tokenize=True, |
| add_generation_prompt=True, |
| return_dict=True, |
| return_tensors="pt", |
| downsample_mode=DOWNSAMPLE_MODE, |
| max_slice_nums=MAX_SLICE_NUMS, |
| ).to(self._device) |
|
|
| with torch.inference_mode(): |
| generated = self._model.generate( |
| **inputs, |
| downsample_mode=DOWNSAMPLE_MODE, |
| max_new_tokens=MAX_NEW_TOKENS, |
| do_sample=False, |
| ) |
|
|
| trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated)] |
| text = self._processor.batch_decode( |
| trimmed, |
| skip_special_tokens=True, |
| clean_up_tokenization_spaces=False, |
| )[0] |
|
|
| return _parse_defect_json(text) |
|
|
| def close(self) -> None: |
| if self._model is not None: |
| del self._model |
| self._model = None |
| if self._processor is not None: |
| del self._processor |
| self._processor = None |
|
|
|
|
| def _parse_defect_json(text: str) -> dict: |
| """Extract and parse the first JSON object from model output.""" |
| text = text.strip() |
| try: |
| return json.loads(text) |
| except json.JSONDecodeError: |
| pass |
|
|
| match = re.search(r"\{[\s\S]*\}", text) |
| if not match: |
| logger.warning("No JSON found in model output: %r", text[:200]) |
| return {"defects": [], "_raw": text, "_parse_error": "no_json_object"} |
| try: |
| return json.loads(match.group(0)) |
| except json.JSONDecodeError as exc: |
| logger.warning("JSON parse error: %s; raw: %r", exc, text[:200]) |
| return {"defects": [], "_raw": text, "_parse_error": str(exc)} |
|
|
|
|
# Module-level detector handed out by get_detector(); built on first use.
_default_detector: MiniCPMVDetector | None = None


def get_detector() -> MiniCPMVDetector:
    """Return the shared module-level detector, creating it on first call.

    The detector is constructed with default arguments. Construction does
    not load the model.
    """
    global _default_detector
    detector = _default_detector
    if detector is None:
        detector = MiniCPMVDetector()
        _default_detector = detector
    return detector
|
|