halide / models /vision /minicpm_wrapper.py
Lonelyguyse1's picture
Initial Halide Space: pipeline, UI, autumn theme
f55d589 verified
Raw
History Blame Contribute Delete
5.17 kB
"""MiniCPM-V 4.6 wrapper. Loads the model and runs inference on film scans."""
from __future__ import annotations
import json
import logging
import os
import re
from pathlib import Path
from typing import Any
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Third ancestor directory of this file's resolved path (parents[2]).
REPO_ROOT = Path(__file__).resolve().parents[2]
# Location where a locally merged checkpoint is looked for.
LOCAL_MODEL_PATH = REPO_ROOT.joinpath("checkpoints", "minicpm-v-4.6-merged")

# Hub repo id used when no local checkpoint exists but HF_TOKEN is set.
HF_MODEL_ID = "Lonelyguyse1/halide-vision"
# Last-resort model id.
BASE_MODEL_ID = "openbmb/MiniCPM-V-4_6"

# Inference settings, read once at import time; each can be overridden through
# the named environment variable. The integer settings raise ValueError at
# import if the variable is set to something int() cannot parse.
DOWNSAMPLE_MODE = os.environ.get("HALIDE_DOWNSAMPLE_MODE", "4x")
MAX_SLICE_NUMS = int(os.environ.get("HALIDE_MAX_SLICE_NUMS", "36"))
MAX_NEW_TOKENS = int(os.environ.get("HALIDE_MAX_NEW_TOKENS", "3072"))

# Instruction sent with every image. The fragments are split one sentence per
# line; the concatenated text is what the model receives.
DETECTION_PROMPT = (
    "You are a film defect detection engine. "
    "Analyze the film scan and detect all visible defects. "
    "Output a JSON object with a 'defects' array. "
    "Each defect has: 'label' (dust, dirt, scratch, long_hair, short_hair), "
    "'bbox' (normalized [x_min, y_min, x_max, y_max] from 0.0 to 1.0). "
    "Output JSON only, no explanation."
)
def _resolve_model_path() -> str:
    """Choose which model identifier should be loaded.

    Candidates are tried in this order:

    1. ``LOCAL_MODEL_PATH``, when that path exists and contains a
       ``config.json``.
    2. ``HF_MODEL_ID``, when the ``HF_TOKEN`` environment variable is set to
       a non-empty value.
    3. ``BASE_MODEL_ID`` otherwise.

    Returns:
        The local checkpoint path as a string, or a Hub repo id.
    """
    config_file = LOCAL_MODEL_PATH / "config.json"
    has_local_checkpoint = LOCAL_MODEL_PATH.exists() and config_file.exists()

    if has_local_checkpoint:
        logger.info("Using local merged model at %s", LOCAL_MODEL_PATH)
        chosen = str(LOCAL_MODEL_PATH)
    elif os.environ.get("HF_TOKEN"):
        logger.info("Using HF Hub repo %s", HF_MODEL_ID)
        chosen = HF_MODEL_ID
    else:
        logger.info("Falling back to base model %s", BASE_MODEL_ID)
        chosen = BASE_MODEL_ID
    return chosen
class MiniCPMVDetector:
    """Film defect detector backed by MiniCPM-V 4.6.

    Construction is cheap: the processor and model are only loaded on the
    first call to :meth:`load` or :meth:`detect`.
    """

    def __init__(self, model_path: str | None = None) -> None:
        """Record which model to use without loading anything.

        Args:
            model_path: Local directory or Hub repo id. When ``None`` or
                empty, the value returned by ``_resolve_model_path()`` is
                used instead.
        """
        self._model_path = model_path if model_path else _resolve_model_path()
        self._model: Any = None
        self._processor: Any = None
        self._dtype: Any = None
        # Overwritten by load() with the device of the first model parameter.
        self._device: str = "cpu"

    @property
    def model_path(self) -> str:
        """Path or repo id this detector loads from."""
        return self._model_path

    def load(self) -> None:
        """Load the processor and the model; do nothing if a model is loaded."""
        already_loaded = self._model is not None
        if already_loaded:
            return

        # Heavy third-party imports are deferred until a load is requested.
        import torch
        from transformers import AutoModelForImageTextToText, AutoProcessor

        source = self._model_path
        logger.info("Loading MiniCPM-V 4.6 from %s", source)

        self._processor = AutoProcessor.from_pretrained(
            source,
            trust_remote_code=True,
        )

        self._dtype = torch.bfloat16
        model_kwargs = {
            "torch_dtype": self._dtype,
            "device_map": "auto",
            "trust_remote_code": True,
        }
        self._model = AutoModelForImageTextToText.from_pretrained(
            source, **model_kwargs
        )

        # Inputs are later moved to wherever the first parameter was placed.
        first_param = next(self._model.parameters())
        self._device = str(first_param.device)
        logger.info("Model loaded on %s", self._device)

    def detect(self, image: Any) -> dict:
        """Run defect detection on one image and return the parsed result.

        Loads the model first if that has not happened yet.

        Args:
            image: The film scan (a PIL image, per the original docstring).

        Returns:
            Whatever ``_parse_defect_json`` produces from the generated text.
        """
        import torch

        if self._model is None:
            self.load()

        # One user turn holding the image followed by the instruction text.
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": DETECTION_PROMPT},
            ],
        }
        batch = self._processor.apply_chat_template(
            [user_turn],
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
            downsample_mode=DOWNSAMPLE_MODE,
            max_slice_nums=MAX_SLICE_NUMS,
        )
        batch = batch.to(self._device)

        with torch.inference_mode():
            sequences = self._model.generate(
                **batch,
                downsample_mode=DOWNSAMPLE_MODE,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
            )

        # Drop the leading len(prompt) tokens from every generated sequence.
        completions = []
        for prompt_ids, full_ids in zip(batch.input_ids, sequences):
            completions.append(full_ids[len(prompt_ids):])

        decoded = self._processor.batch_decode(
            completions,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return _parse_defect_json(decoded[0])

    def close(self) -> None:
        """Drop this object's references to the model and the processor.

        A later :meth:`load` or :meth:`detect` call loads them again.
        """
        for attr in ("_model", "_processor"):
            if getattr(self, attr) is not None:
                delattr(self, attr)
                setattr(self, attr, None)
def _parse_defect_json(text: str) -> dict:
    """Extract and parse the first JSON object from model output.

    The whole (stripped) text is tried first. If that fails, decoding starts
    at the first ``{`` and stops at the end of that object, so anything the
    model appended after the JSON (prose, code-fence markers, stray braces)
    is ignored instead of breaking the parse.

    Args:
        text: Raw text generated by the model.

    Returns:
        The decoded JSON object. On failure, a dict of the form
        ``{"defects": [], "_raw": <stripped text>, "_parse_error": <reason>}``
        where the reason is ``"no_json_object"`` (no ``{...}`` span present),
        ``"top_level_not_object"`` (valid JSON, but a list/scalar rather than
        an object), or the decoder's error message.
    """
    text = text.strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(parsed, dict):
            return parsed
        # Valid JSON that is not an object cannot satisfy the dict contract.
        # Do not dig inside it for an embedded object either: that would
        # return e.g. a single list element as if it were the whole result.
        logger.warning("Model output is JSON but not an object: %r", text[:200])
        return {
            "defects": [],
            "_raw": text,
            "_parse_error": "top_level_not_object",
        }

    start = text.find("{")
    if start == -1 or text.find("}", start) == -1:
        logger.warning("No JSON found in model output: %r", text[:200])
        return {"defects": [], "_raw": text, "_parse_error": "no_json_object"}

    try:
        # raw_decode stops at the end of the first complete value instead of
        # requiring the remainder of the string to be empty.
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError as exc:
        logger.warning("JSON parse error: %s; raw: %r", exc, text[:200])
        return {"defects": [], "_raw": text, "_parse_error": str(exc)}
    return obj
# Shared detector instance; stays None until get_detector() is first called.
_default_detector: MiniCPMVDetector | None = None


def get_detector() -> MiniCPMVDetector:
    """Return the shared module-level detector, creating it on first use.

    The instance is built with no arguments; constructing it does not load
    the model.
    """
    global _default_detector
    detector = _default_detector
    if detector is None:
        detector = MiniCPMVDetector()
        _default_detector = detector
    return detector