vision-base / core /model.py
SPP
feat: polish UI, reveal animation, Tiny Titan badge, Oracle sampling
3e70669
Raw
History Blame Contribute Delete
2.57 kB
import spaces # ZeroGPU: must precede torch/transformers imports
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
# Checkpoint identifier handed to both `from_pretrained` calls below.
MODEL_ID = "openbmb/MiniCPM-V-4.6"

# Built once at import time and shared by every call in this module.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Weights are requested in bfloat16. `.eval()` returns the module itself, so
# chaining it leaves `model` bound to the loaded model in evaluation mode.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16
).eval()
def _to_pil(img) -> Image.Image:
    """Convert a supported image input into an RGB ``PIL.Image.Image``.

    Args:
        img: One of

            * a PIL image,
            * any object exposing ``__array__`` (for example a NumPy array),
            * a ``str`` or ``os.PathLike`` path to an image file.

    Returns:
        The image converted to RGB mode.

    Raises:
        TypeError: If ``img`` is none of the supported kinds.
    """
    import os

    if isinstance(img, Image.Image):
        return img.convert("RGB")

    if hasattr(img, "__array__"):
        import numpy as np  # local import: only array inputs need NumPy

        # np.asarray passes ndarrays through and calls __array__ otherwise.
        return Image.fromarray(np.asarray(img)).convert("RGB")

    if isinstance(img, (str, os.PathLike)):
        # Image.open keeps the file open until the pixel data is loaded;
        # convert() loads it and returns an independent image, so the
        # handle can be closed deterministically on leaving the block.
        with Image.open(img) as opened:
            return opened.convert("RGB")

    raise TypeError(f"Cannot convert {type(img)} to PIL Image")
@spaces.GPU(duration=120)
def vision_infer(
    images,
    instruction: str,
    json_mode: bool = False,
    max_tokens: int = 768,
    do_sample: bool = False,
    temperature: float = 0.7,
) -> str:
    """Run one image(s)+text generation request and return the decoded reply.

    Args:
        images: A single image, or a list/tuple of images. Each item may be
            anything ``_to_pil`` accepts.
        instruction: The text prompt placed after the images in the user turn.
        json_mode: When true, a suffix asking for JSON-only output is
            appended to ``instruction``.
        max_tokens: Passed to ``generate`` as ``max_new_tokens``.
        do_sample: Passed to ``generate`` as ``do_sample``.
        temperature: Forwarded to ``generate`` only when ``do_sample`` is true.

    Returns:
        The decoded text of the first generated sequence, with the prompt
        tokens removed and special tokens skipped.

    Raises:
        TypeError: If an item of ``images`` cannot be converted by ``_to_pil``.
    """
    if not isinstance(images, (list, tuple)):
        images = [images]

    # Validate and convert inputs before touching the GPU, so a bad image
    # fails fast instead of paying for a device round trip first.
    pil_images = [_to_pil(img) for img in images]

    if json_mode:
        instruction = (
            instruction
            + "\n\nRespond with ONLY valid JSON. No markdown fences, no prose, no explanation."
        )

    # Single user turn: all images first, then the text instruction.
    content = [{"type": "image", "image": img} for img in pil_images]
    content.append({"type": "text", "text": instruction})
    messages = [{"role": "user", "content": content}]

    try:
        # Inside the try so the finally clause restores the model to CPU
        # even if the move to CUDA itself fails.
        model.to("cuda")

        # NOTE(review): downsample_mode / max_slice_nums look like
        # model-specific processor options; confirm against the model card.
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
            downsample_mode="16x",
            max_slice_nums=36,
        ).to(model.device)

        gen_kwargs: dict = {
            "downsample_mode": "16x",
            "max_new_tokens": max_tokens,
            "do_sample": do_sample,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature

        with torch.no_grad():
            generated_ids = model.generate(**inputs, **gen_kwargs)

        # Drop the leading prompt-length slice from each output sequence so
        # only the newly generated tokens are decoded.
        trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        return processor.batch_decode(
            trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
    finally:
        # Always hand the model back to CPU, on success or failure.
        model.to("cpu")