|
|
import os |
|
|
import dataclasses |
|
|
from typing import Optional, Dict, Any, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
|
|
|
from transformers import ( |
|
|
AutoProcessor, |
|
|
AutoModelForVision2Seq, |
|
|
AutoModelForCausalLM, |
|
|
) |
|
|
|
|
|
from peft import PeftModel, PeftConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run every model on the CPU. Set to False to use CUDA when it is available.
FORCE_CPU = True

# Device all models are moved to, derived from FORCE_CPU.
DEVICE = torch.device("cpu" if FORCE_CPU or not torch.cuda.is_available() else "cuda")

# Gate class index -> sufficient max-side length (pixels).
RESOLUTION_MAP = {0: 384, 1: 768, 2: 1024}
|
|
|
|
|
def load_and_resize_image(img: Image.Image, max_size: Optional[int] = None) -> Image.Image:
    """Return an RGB version of ``img`` whose longer side is at most ``max_size``.

    The image is converted to RGB first. It is only ever downscaled. If
    ``max_size`` is None, or the longer side already fits, the converted image
    is returned as is. Otherwise both sides are scaled by the same factor
    (aspect ratio preserved) using bicubic resampling.

    Args:
        img: input PIL image (any mode).
        max_size: upper bound for the longer side in pixels, or None to skip resizing.

    Returns:
        The converted (and possibly downscaled) image.
    """
    img = img.convert("RGB")
    if max_size is None:
        return img

    w, h = img.size
    longest = max(w, h)
    if longest <= max_size:
        return img

    scale = max_size / longest
    # Clamp to 1 px: for very elongated images the short side can otherwise
    # round to 0, which PIL's resize rejects.
    new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
    return img.resize(new_size, Image.BICUBIC)
|
|
|
|
|
def token_id_for_digit(tokenizer, digit: str) -> int:
    """Look up the vocabulary id used for ``digit`` as a generated token.

    ``digit`` is encoded without special tokens. When the tokenizer yields
    several ids, the last one is returned (presumably the digit itself, after
    any prefix piece the tokenizer inserts).

    Raises:
        ValueError: if encoding ``digit`` yields no ids at all.
    """
    encoded = tokenizer.encode(digit, add_special_tokens=False)
    if len(encoded) == 0:
        raise ValueError(f"Could not encode digit {digit!r}")
    *_, last = encoded
    return last
|
|
|
|
|
class GraniteDoclingGateHF:
    """Resolution gate: a PEFT (LoRA) adapter on a Hugging Face vision-to-sequence model.

    The adapter's PEFT config supplies the base model name, and the processor is
    loaded from the adapter repo. For an (image, question) pair, the gate takes
    the logits of the next token after the chat prompt, restricted to the tokens
    for the digits "0", "1" and "2", and softmaxes them into probabilities over
    the resolution classes indexed by ``RESOLUTION_MAP``.
    """

    def __init__(self, adapter_repo: str, token: Optional[str] = None):
        """Load processor, base model and adapter, then resolve the class-token ids.

        Args:
            adapter_repo: Hub id or local path of the PEFT adapter.
            token: optional Hugging Face access token. It is forwarded to every
                download (adapter config, processor, base model, adapter
                weights) so private or gated repos work.
        """
        self.device = DEVICE

        peft_cfg = PeftConfig.from_pretrained(adapter_repo, token=token)
        base_model_name = peft_cfg.base_model_name_or_path

        self.processor = AutoProcessor.from_pretrained(adapter_repo, token=token)

        # float32 weights; presumably chosen because DEVICE is the CPU here.
        base_model = AutoModelForVision2Seq.from_pretrained(
            base_model_name, torch_dtype=torch.float32, token=token
        )
        self.model = PeftModel.from_pretrained(base_model, adapter_repo, token=token)
        self.model.to(self.device).eval()

        # Token id of each class label, in class-index order (0, 1, 2).
        tok = self.processor.tokenizer
        self.class_token_ids = [token_id_for_digit(tok, d) for d in ("0", "1", "2")]

    @torch.no_grad()
    def predict_probs(self, image: Image.Image, question: str) -> list:
        """Return ``[p0, p1, p2]``: the gate's probability for each resolution class.

        Only the three class-token logits at the last prompt position are used,
        so the softmax is taken over those three values alone.
        """
        messages = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}
        ]
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(text=[prompt], images=[image], return_tensors="pt")
        # Move tensor entries to the model device; drop non-tensor metadata.
        inputs = {k: v.to(self.device) for k, v in inputs.items() if hasattr(v, "to")}

        outputs = self.model(**inputs)
        next_token_logits = outputs.logits[:, -1, :]
        class_logits = next_token_logits[:, self.class_token_ids]
        return F.softmax(class_logits, dim=-1)[0].float().cpu().tolist()

    def predict_expected(self, image: Image.Image, question: str) -> float:
        """Return the probability-weighted mean of the ``RESOLUTION_MAP`` max-sides."""
        probs = self.predict_probs(image, question)
        return float(sum(RESOLUTION_MAP[i] * p for i, p in enumerate(probs)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# (dropdown label, Hugging Face model id) pairs for the downstream VLM, in UI order.
# The label suffix is a rough CPU-speed hint; a None id means "gate only".
_DOWNSTREAM_MODEL_ENTRIES = (
    ("ibm-granite/granite-vision-3.3-2b (recommended CPU)", "ibm-granite/granite-vision-3.3-2b"),
    ("HuggingFaceTB/SmolVLM-256M-Instruct (tiny CPU)", "HuggingFaceTB/SmolVLM-256M-Instruct"),
    ("google/paligemma-3b-mix-224 (CPU)", "google/paligemma-3b-mix-224"),
    ("Qwen/Qwen2.5-VL-3B-Instruct (slow CPU)", "Qwen/Qwen2.5-VL-3B-Instruct"),
    ("Qwen/Qwen2.5-VL-7B-Instruct (very slow CPU)", "Qwen/Qwen2.5-VL-7B-Instruct"),
    ("Qwen/Qwen2.5-VL-72B-Instruct (not for CPU)", "Qwen/Qwen2.5-VL-72B-Instruct"),
    ("Qwen/Qwen3-VL-8B-Instruct (very slow CPU)", "Qwen/Qwen3-VL-8B-Instruct"),
    ("OpenGVLab/InternVL3_5-8B (very slow CPU)", "OpenGVLab/InternVL3_5-8B"),
    ("OpenGVLab/InternVL3_5-38B (not for CPU)", "OpenGVLab/InternVL3_5-38B"),
    ("OpenGVLab/InternVL3_5-241B-A28B (not for CPU)", "OpenGVLab/InternVL3_5-241B-A28B"),
    ("None (gate only)", None),
)
# Label -> model id (insertion order preserved, so the first entry is the UI default).
DOWNSTREAM_MODELS = dict(_DOWNSTREAM_MODEL_ENTRIES)

import inspect

# model id -> (processor, model), filled lazily by get_vlm().
_model_cache: Dict[str, Tuple[Any, Any]] = {}

# Intra-op thread count for torch; override with the TORCH_NUM_THREADS env var.
_torch_threads = int(os.environ.get("TORCH_NUM_THREADS", "4"))
torch.set_num_threads(_torch_threads)
|
|
|
|
|
def get_vlm(model_id: str, max_cached: Optional[int] = 1):
    """Return ``(processor, model)`` for ``model_id``, loading it on first use.

    The processor and model are loaded with ``trust_remote_code=True`` and
    float32 weights. ``AutoModelForVision2Seq`` is tried first, then
    ``AutoModelForCausalLM``. The model is moved to ``DEVICE``, put in eval mode
    and stored in ``_model_cache``.

    Args:
        model_id: Hugging Face Hub id of the downstream VLM.
        max_cached: maximum number of models kept in ``_model_cache``. Values
            below 1 are treated as 1. Before a new model is loaded, the oldest
            entries are dropped, so switching models in the UI does not keep
            several full-precision models in RAM at once. ``None`` disables
            eviction (unbounded cache).

    Returns:
        The ``(processor, model)`` tuple.

    Raises:
        RuntimeError: if the model loads with neither auto class. The message
            includes both underlying errors.
    """
    if model_id in _model_cache:
        return _model_cache[model_id]

    if max_cached is not None:
        limit = max(1, max_cached)
        evicted = False
        while len(_model_cache) >= limit:
            # Dicts keep insertion order, so the first key is the oldest entry.
            del _model_cache[next(iter(_model_cache))]
            evicted = True
        if evicted:
            import gc

            # Release the evicted weights before allocating the next model.
            gc.collect()

    proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    load_kwargs: Dict[str, Any] = {
        "torch_dtype": torch.float32,
        "trust_remote_code": True,
    }
    # Only pass low_cpu_mem_usage when from_pretrained declares it explicitly.
    # NOTE(review): auto-class from_pretrained usually takes options through
    # **kwargs, so this probe may never match - confirm before relying on it.
    try:
        params = inspect.signature(AutoModelForVision2Seq.from_pretrained).parameters
    except (TypeError, ValueError):
        params = {}
    if "low_cpu_mem_usage" in params:
        load_kwargs["low_cpu_mem_usage"] = True

    try:
        model = AutoModelForVision2Seq.from_pretrained(model_id, **load_kwargs)
    except Exception as vision_err:
        # Fallback for models that presumably only register a causal-LM auto class.
        try:
            model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        except Exception as causal_err:
            raise RuntimeError(
                f"Could not load {model_id!r}: AutoModelForVision2Seq failed with "
                f"{type(vision_err).__name__}: {vision_err}; AutoModelForCausalLM "
                f"failed with {type(causal_err).__name__}: {causal_err}"
            ) from causal_err

    model.to(DEVICE).eval()
    _model_cache[model_id] = (proc, model)
    return proc, model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_vlm_inputs(proc, image: Image.Image, question: str):
    """Turn ``question`` + ``image`` into model inputs with a generic HF processor.

    Tries, in order:
    1. the processor's tokenizing chat template;
    2. the chat template rendered to text, then passed through the processor;
    3. the raw question without any template.
    """
    conversation = [{
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": question}],
    }]

    inputs = None
    if hasattr(proc, "apply_chat_template"):
        try:
            inputs = proc.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                images=image,
            )
        except Exception:
            # Presumably processors whose apply_chat_template does not accept
            # these kwargs: render the prompt as text and tokenize separately.
            try:
                prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)
                inputs = proc(text=[prompt], images=[image], return_tensors="pt")
            except Exception:
                inputs = None

    if inputs is None:
        # Last resort: no chat template at all.
        inputs = proc(text=[question], images=[image], return_tensors="pt")
    return inputs


def _decode_vlm_output(proc, out, prompt_ids, question: str) -> str:
    """Decode the first sequence of ``generate()`` output into answer text.

    If ``out`` begins with exactly ``prompt_ids``, only the newly generated
    tokens after them are decoded. Otherwise the whole sequence is decoded and,
    as a heuristic, everything up to the first echo of ``question`` is dropped.
    """
    if (
        isinstance(out, torch.Tensor)
        and isinstance(prompt_ids, torch.Tensor)
        and out.dim() == 2
        and prompt_ids.dim() == 2
    ):
        n_prompt = prompt_ids.shape[1]
        if out.shape[1] >= n_prompt and torch.equal(out[:, :n_prompt], prompt_ids):
            return proc.batch_decode(out[:, n_prompt:], skip_special_tokens=True)[0].strip()

    text = proc.batch_decode(out, skip_special_tokens=True)[0].strip()
    if question in text and len(text) > 2 * len(question):
        text = text.split(question, 1)[-1].strip()
    return text


@torch.no_grad()
def vlm_answer(model_id: str, image: Image.Image, question: str, max_new_tokens: int = 96) -> str:
    """Answer ``question`` about ``image`` with the downstream VLM ``model_id``.

    If the loaded model has a callable ``chat`` attribute, it is tried first.
    Any exception from it is swallowed and the generic path is used instead.
    The generic path builds inputs through the processor (chat template with
    fallbacks), runs ``model.generate`` and decodes only the generated
    continuation whenever the output starts with the prompt ids.

    Args:
        model_id: Hub id passed to ``get_vlm``.
        image: image to ask about.
        question: user question.
        max_new_tokens: generation budget passed to ``model.generate``.

    Returns:
        The stripped answer text.
    """
    proc, model = get_vlm(model_id)

    if callable(getattr(model, "chat", None)):
        try:
            return str(model.chat(proc, image, question)).strip()
        except Exception:
            # Best effort only: custom ``chat`` signatures differ between model
            # families, so fall back to the generic generate() path below.
            pass

    inputs = _build_vlm_inputs(proc, image, question)
    # Move tensor entries to the model device; drop non-tensor metadata.
    inputs = {k: v.to(DEVICE) for k, v in inputs.items() if hasattr(v, "to")}

    out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return _decode_vlm_output(proc, out, inputs.get("input_ids"), question)
|
|
|
|
|
def cpu_model_allowed(model_id: str) -> Tuple[bool, str]:
    """Decide whether ``model_id`` may be loaded on this CPU-only deployment.

    The id is lower-cased and checked for size markers of very large
    checkpoints ("72b", "38b", "241b", "a28b"). A falsy id is allowed.

    Returns:
        ``(allowed, reason)``; ``reason`` is empty when allowed.
    """
    lowered = model_id.lower() if model_id else ""
    too_big_markers = ("72b", "38b", "241b", "a28b")
    for marker in too_big_markers:
        if marker in lowered:
            return False, "Too large for CPU Space (will OOM)."
    return True, ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def choose_resolution(expected: float, probs: list, strategy: str) -> int:
    """Collapse the gate output into one target max-side length (pixels).

    Strategies:
        "expected": the probability-weighted max-side, rounded to an int.
        "argmax": the ``RESOLUTION_MAP`` entry of the most likely class
            (first one on ties).
        anything else ("conservative"): the largest class whose probability
            is at least 0.34, checking class 2 then class 1, else class 0.
    """
    if strategy == "argmax":
        best = max(range(len(probs)), key=probs.__getitem__)
        return int(RESOLUTION_MAP[best])
    if strategy == "expected":
        return int(round(expected))

    for cls in (2, 1):
        if probs[cls] >= 0.34:
            return int(RESOLUTION_MAP[cls])
    return int(RESOLUTION_MAP[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gate configuration, each value overridable through an environment variable.
GATE_ADAPTER_REPO = os.environ.get("GATE_ADAPTER_REPO", "Kimhi/granite-docling-res-gate-lora")
HF_TOKEN = os.environ.get("HF_TOKEN")
# Longest side (pixels) of the downscaled copy fed to the gate.
GATE_INPUT_MAX_SIDE = int(os.environ.get("GATE_INPUT_MAX_SIDE", "256"))

# Lazily created gate instance (see run()).
gate = None
|
|
|
|
|
def run(image: Image.Image, question: str, vlm_choice: str, strategy: str):
    """Gradio callback: pick a sufficient resolution, resize, and query the VLM.

    Steps:
    1. The gate scores a downscaled copy of the image.
    2. ``choose_resolution`` turns the scores into a target max-side,
       clamped to the native size.
    3. The image is resized to that size and sent to the selected VLM,
       unless "gate only" is chosen or the model is blocked on CPU.

    Returns:
        ``(info_text, resized_image, answer)`` for the three output widgets,
        or ``(message, None, None)`` when the image or question is missing.
    """
    global gate

    # Validate before the (slow, one-time) gate load so empty submits stay cheap.
    if image is None or not question or not question.strip():
        return "Upload an image and enter a question.", None, None

    if gate is None:
        gate = GraniteDoclingGateHF(adapter_repo=GATE_ADAPTER_REPO, token=HF_TOKEN)

    native_w, native_h = image.size

    # The gate only sees a small copy of the image.
    gate_img = load_and_resize_image(image, GATE_INPUT_MAX_SIDE)
    probs = gate.predict_probs(gate_img, question)
    expected = float(sum(RESOLUTION_MAP[i] * p for i, p in enumerate(probs)))

    pred = choose_resolution(expected, probs, strategy)

    # Never target more than the native size (load_and_resize_image only downscales).
    used_max = min(pred, max(native_w, native_h))
    resized = load_and_resize_image(image, used_max)
    resized_w, resized_h = resized.size

    model_id = DOWNSTREAM_MODELS.get(vlm_choice)
    if model_id is None:
        answer = "(gate only) No VLM selected."
    else:
        ok, reason = cpu_model_allowed(model_id)
        if not ok:
            answer = f"Blocked on CPU: {reason}"
        else:
            try:
                answer = vlm_answer(model_id, resized, question)
            except Exception as e:
                # Show load/generation failures in the UI instead of failing the callback.
                answer = f"VLM error: {type(e).__name__}: {e}"

    info_lines = [
        f"Native: {native_w}×{native_h}\n",
        f"Sufficient max-side: {expected:.1f}\n",
    ]
    if strategy != "expected":
        info_lines += [
            f"Strategy: {strategy}\n",
            f"Predicted sufficient max-side: {pred}\n",
            f"Used max-side (clamped to native): {used_max}\n",
            f"Resized sent to VLM: {resized_w}×{resized_h}\n",
        ]
    info_lines.append(f"VLM: {vlm_choice}\n")
    info = "".join(info_lines)

    return info, resized, answer
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# CARES – Sufficient Resolution Selection for VLMs")

    # Dropdown labels in declaration order; the first one is the default choice.
    _vlm_labels = list(DOWNSTREAM_MODELS)

    with gr.Row():
        inp_img = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            inp_q = gr.Textbox(label="Question", placeholder="Ask something about the image…")
            vlm = gr.Dropdown(label="VLM", choices=_vlm_labels, value=_vlm_labels[0])
            strategy = gr.Dropdown(
                label="Resolution selection strategy",
                choices=["expected", "argmax", "conservative"],
                value="expected",
            )
            btn = gr.Button("Run")

    # Outputs, in the same order as run()'s return tuple.
    out_info = gr.Textbox(label="Info", lines=10)
    out_img = gr.Image(type="pil", label="Image used for inference (sufficient resolution)")
    out_ans = gr.Textbox(label="Answer", lines=6)

    btn.click(run, inputs=[inp_img, inp_q, vlm, strategy], outputs=[out_info, out_img, out_ans])

demo.launch(ssr_mode=False)
|
|
|