import os
import inspect
import dataclasses
from typing import Optional, Dict, Any, Tuple
import torch
import torch.nn.functional as F
import gradio as gr
from PIL import Image
from transformers import (
    AutoProcessor,
    AutoModelForVision2Seq,
    AutoModelForCausalLM,
)
from peft import PeftModel, PeftConfig
# -----------------------------
# CPU-only enforcement
# -----------------------------
FORCE_CPU = True
DEVICE = torch.device("cpu")
# -----------------------------
# Resolution gate
# -----------------------------
RESOLUTION_MAP = {0: 384, 1: 768, 2: 1024}
def load_and_resize_image(img: Image.Image, max_size: Optional[int] = None) -> Image.Image:
    img = img.convert("RGB")
    if max_size is None:
        return img
    w, h = img.size
    if max(w, h) <= max_size:
        return img
    s = max_size / max(w, h)
    return img.resize((round(w * s), round(h * s)), Image.BICUBIC)
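# Example: a 3000×2000 input with max_size=1024 is scaled by 1024/3000 ≈ 0.341,
# i.e. resized to 1024×683; images already at or below max_size pass through unchanged.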
def token_id_for_digit(tokenizer, digit: str) -> int:
    """Return the token id of a single digit ("0", "1", "2") used as a class label."""
    ids = tokenizer.encode(digit, add_special_tokens=False)
    if not ids:
        raise ValueError(f"Could not encode digit {digit!r}")
    # Some tokenizers emit a prefix token (e.g. a leading space) before the digit;
    # the last id is the one corresponding to the digit itself.
    return ids[-1]
class GraniteDoclingGateHF:
    def __init__(self, adapter_repo: str, token: Optional[str] = None):
        self.device = DEVICE
        peft_cfg = PeftConfig.from_pretrained(adapter_repo, token=token)
        base_model_name = peft_cfg.base_model_name_or_path
        self.processor = AutoProcessor.from_pretrained(adapter_repo, token=token)
        # CPU: use float32 for safety (bfloat16/float16 often slower or problematic on CPU)
        torch_dtype = torch.float32
        base_model = AutoModelForVision2Seq.from_pretrained(
            base_model_name, torch_dtype=torch_dtype
        )
        self.model = PeftModel.from_pretrained(base_model, adapter_repo, token=token)
        self.model.to(self.device).eval()
        tok = self.processor.tokenizer
        self.class_token_ids = [
            token_id_for_digit(tok, "0"),
            token_id_for_digit(tok, "1"),
            token_id_for_digit(tok, "2"),
        ]
    @torch.no_grad()
    def predict_probs(self, image: Image.Image, question: str):
        messages = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}
        ]
        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = self.processor(text=[prompt], images=[image], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items() if hasattr(v, "to")}
        outputs = self.model(**inputs)
        next_token_logits = outputs.logits[:, -1, :]
        class_logits = next_token_logits[:, self.class_token_ids]
        probs = F.softmax(class_logits, dim=-1)[0].detach().float().cpu().tolist()
        return probs
    def predict_expected(self, image: Image.Image, question: str) -> float:
        probs = self.predict_probs(image, question)
        return float(sum(RESOLUTION_MAP[i] * probs[i] for i in range(3)))
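# Illustrative standalone use of the gate (a sketch; assumes the LoRA adapter repo
# configured below via GATE_ADAPTER_REPO is reachable and "page.png" is a local file):
#
#   gate = GraniteDoclingGateHF("Kimhi/granite-docling-res-gate-lora")
#   img = load_and_resize_image(Image.open("page.png"), 256)            # gate sees a small copy
#   probs = gate.predict_probs(img, "What is the document title?")      # [p_384, p_768, p_1024]
#   max_side = gate.predict_expected(img, "What is the document title?")  # expected max side, px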
# -----------------------------
# CPU-friendly downstream HF VLM inference
# -----------------------------
# IMPORTANT: Choose models that can run on CPU.
# Many VLMs will be too slow/heavy on CPU; start small.
DOWNSTREAM_MODELS = {
    "ibm-granite/granite-vision-3.3-2b (recommended CPU)": "ibm-granite/granite-vision-3.3-2b",
    "HuggingFaceTB/SmolVLM-256M-Instruct (tiny CPU)": "HuggingFaceTB/SmolVLM-256M-Instruct",
    "google/paligemma-3b-mix-224 (CPU)": "google/paligemma-3b-mix-224",
    # Your list (kept available but not recommended on CPU)
    "Qwen/Qwen2.5-VL-3B-Instruct (slow CPU)": "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen2.5-VL-7B-Instruct (very slow CPU)": "Qwen/Qwen2.5-VL-7B-Instruct",
    "Qwen/Qwen2.5-VL-72B-Instruct (not for CPU)": "Qwen/Qwen2.5-VL-72B-Instruct",
    "Qwen/Qwen3-VL-8B-Instruct (very slow CPU)": "Qwen/Qwen3-VL-8B-Instruct",
    "OpenGVLab/InternVL3_5-8B (very slow CPU)": "OpenGVLab/InternVL3_5-8B",
    "OpenGVLab/InternVL3_5-38B (not for CPU)": "OpenGVLab/InternVL3_5-38B",
    "OpenGVLab/InternVL3_5-241B-A28B (not for CPU)": "OpenGVLab/InternVL3_5-241B-A28B",
    "None (gate only)": None,
}
_model_cache: Dict[str, Tuple[Any, Any]] = {}
torch.set_num_threads(int(os.getenv("TORCH_NUM_THREADS", "4")))
def get_vlm(model_id: str):
    if model_id in _model_cache:
        return _model_cache[model_id]
    proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    common_kwargs = dict(
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    # low_cpu_mem_usage exists on many HF models; use it if supported
    try:
        sig = inspect.signature(AutoModelForVision2Seq.from_pretrained)
        if "low_cpu_mem_usage" in sig.parameters:
            common_kwargs["low_cpu_mem_usage"] = True
    except Exception:
        pass
    model = None
    err = None
    # Try Vision2Seq first
    try:
        model = AutoModelForVision2Seq.from_pretrained(model_id, **common_kwargs)
    except Exception as e:
        err = e
        # Fallback: CausalLM
        model = AutoModelForCausalLM.from_pretrained(model_id, **common_kwargs)
    model.to(DEVICE).eval()
    _model_cache[model_id] = (proc, model)
    return proc, model
@torch.no_grad()
def vlm_answer(model_id: str, image: Image.Image, question: str, max_new_tokens: int = 96) -> str:
    proc, model = get_vlm(model_id)
    # ---- Path A: model.chat (InternVL-style, some others) ----
    if hasattr(model, "chat") and callable(getattr(model, "chat")):
        try:
            # Different repos have different signatures; this is the most common pattern.
            # If it fails, we fall back to processor + generate.
            return str(model.chat(proc, image, question)).strip()
        except Exception:
            pass
    # ---- Path B: processor + generate ----
    conversation = [{
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": question}],
    }]
    inputs = None
    if hasattr(proc, "apply_chat_template"):
        # Try the Granite-style "tokenize=True" path
        try:
            inputs = proc.apply_chat_template(
                conversation,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                images=image,  # supported by some processors
            )
        except Exception:
            # Fallback: build a prompt string, then call the processor
            try:
                prompt = proc.apply_chat_template(conversation, add_generation_prompt=True)
                inputs = proc(text=[prompt], images=[image], return_tensors="pt")
            except Exception:
                inputs = None
    if inputs is None:
        # Final fallback: no chat template
        inputs = proc(text=[question], images=[image], return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items() if hasattr(v, "to")}
    out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    text = proc.batch_decode(out, skip_special_tokens=True)[0].strip()
    # Prompt-echo cleanup (best effort)
    if question in text and len(text) > 2 * len(question):
        text = text.split(question, 1)[-1].strip()
    return text
def cpu_model_allowed(model_id: str) -> Tuple[bool, str]:
    mid = (model_id or "").lower()
    blocked = ["72b", "38b", "241b", "a28b"]
    if any(b in mid for b in blocked):
        return False, "Too large for CPU Space (will OOM)."
    return True, ""
# -----------------------------
# Resolution selection strategy
# -----------------------------
def choose_resolution(expected: float, probs: list, strategy: str) -> int:
    if strategy == "expected":
        return int(round(expected))
    if strategy == "argmax":
        k = int(max(range(len(probs)), key=lambda i: probs[i]))
        return int(RESOLUTION_MAP[k])
    # conservative: choose the highest bucket if it has meaningful mass, else the next, else the lowest
    if probs[2] >= 0.34:
        return int(RESOLUTION_MAP[2])
    if probs[1] >= 0.34:
        return int(RESOLUTION_MAP[1])
    return int(RESOLUTION_MAP[0])
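# Worked example (illustrative probabilities, not model output):
#   probs = [0.2, 0.5, 0.3] -> expected = 0.2*384 + 0.5*768 + 0.3*1024 = 768.0
#   "expected"     -> 768  (rounded expectation)
#   "argmax"       -> 768  (bucket 1 carries the most mass)
#   "conservative" -> 768  (probs[2] = 0.3 < 0.34, probs[1] = 0.5 >= 0.34)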
# -----------------------------
# Gradio app
# -----------------------------
GATE_ADAPTER_REPO = os.getenv("GATE_ADAPTER_REPO", "Kimhi/granite-docling-res-gate-lora")
HF_TOKEN = os.getenv("HF_TOKEN", None)
GATE_INPUT_MAX_SIDE = int(os.getenv("GATE_INPUT_MAX_SIDE", "256"))
gate = None
def run(image: Image.Image, question: str, vlm_choice: str, strategy: str):
    global gate
    if gate is None:
        gate = GraniteDoclingGateHF(adapter_repo=GATE_ADAPTER_REPO, token=HF_TOKEN)
    if image is None or not question:
        return "Upload an image and enter a question.", None, None
    native_w, native_h = image.size
    # The gate runs on a small copy of the image
    gate_img = load_and_resize_image(image, GATE_INPUT_MAX_SIDE)
    probs = gate.predict_probs(gate_img, question)
    expected = float(sum(RESOLUTION_MAP[i] * probs[i] for i in range(3)))
    pred = choose_resolution(expected, probs, strategy)
    # Never upscale above the native max side
    native_max = max(native_w, native_h)
    used_max = min(pred, native_max)
    resized = load_and_resize_image(image, used_max)
    resized_w, resized_h = resized.size
    model_id = DOWNSTREAM_MODELS.get(vlm_choice)
    if model_id is None:
        answer = "(gate only) No VLM selected."
    else:
        ok, reason = cpu_model_allowed(model_id)
        if not ok:
            answer = f"Blocked on CPU: {reason}"
        else:
            try:
                answer = vlm_answer(model_id, resized, question)
            except Exception as e:
                answer = f"VLM error: {type(e).__name__}: {e}"
if strategy != "expected":
info = (
f"Native: {native_w}×{native_h}\n"
#f"Gate probs [384,768,1024]: {['%.3f'%p for p in probs]}\n"
f"Sufficient max-side: {expected:.1f}\n"
f"Strategy: {strategy}\n"
f"Predicted sufficient max-side: {pred}\n"
f"Used max-side (clamped to native): {used_max}\n"
f"Resized sent to VLM: {resized_w}×{resized_h}\n"
f"VLM: {vlm_choice}\n"
)
else:
info = (
f"Native: {native_w}×{native_h}\n"
#f"Gate probs [384,768,1024]: {['%.3f'%p for p in probs]}\n"
f"Sufficient max-side: {expected:.1f}\n"
#f"Strategy: {strategy}\n"
#f"Predicted sufficient max-side: {pred}\n"
#f"Used max-side (clamped to native): {used_max}\n"
#f"Resized sent to VLM: {resized_w}×{resized_h}\n"
f"VLM: {vlm_choice}\n")
return info, resized, answer
with gr.Blocks() as demo:
    gr.Markdown("# CARES – Sufficient Resolution Selection for VLMs")
    with gr.Row():
        inp_img = gr.Image(type="pil", label="Upload image")
        with gr.Column():
            inp_q = gr.Textbox(label="Question", placeholder="Ask something about the image…")
            vlm = gr.Dropdown(
                choices=list(DOWNSTREAM_MODELS.keys()),
                value=list(DOWNSTREAM_MODELS.keys())[0],
                label="VLM",
            )
            strategy = gr.Dropdown(
                choices=["expected", "argmax", "conservative"],
                value="expected",
                label="Resolution selection strategy",
            )
            btn = gr.Button("Run")
    out_info = gr.Textbox(label="Info", lines=10)
    out_img = gr.Image(type="pil", label="Image used for inference (sufficient resolution)")
    out_ans = gr.Textbox(label="Answer", lines=6)
    btn.click(run, inputs=[inp_img, inp_q, vlm, strategy], outputs=[out_info, out_img, out_ans])
demo.launch(ssr_mode=False)