TemperCheck / tempercheck /inference.py
Joseph Antolick
Rename Crankycheck -> TemperCheck
aa00509
Raw
History Blame Contribute Delete
6.31 kB
"""Inference layer for TemperCheck.
Two interchangeable backends, selected by the TEMPER_BACKEND env var (which
defaults to "transformers" on a Hugging Face Space, "ollama" elsewhere):
- "transformers" (the Hugging Face Space / ZeroGPU) — loads google/gemma-4-E4B-it
with transformers and runs inference inside a @spaces.GPU
function. This is the deployment path; Gemma 4 vision works
here (it is broken in the local Ollama builds — see CLAUDE.md).
- "ollama" (local experimentation) — calls a local Ollama server. Fast and
torch-free, but the local Gemma 4 vision is unreliable, so this
is for plumbing/UI work, not real verdicts.
Everything else in the app talks to `score_image()` and never imports a backend
directly, so the model can be swapped without touching the UI (see CLAUDE.md).
"""
from __future__ import annotations
import base64
import io
import os
from PIL import Image
from .prompt import (
SYSTEM_PROMPT,
USER_INSTRUCTION,
TemperVerdict,
build_messages,
parse_verdict,
)
# On a Space, default to the transformers backend; locally, default to Ollama.
_ON_SPACE = bool(os.environ.get("SPACE_ID"))
BACKEND = os.environ.get(
"TEMPER_BACKEND", "transformers" if _ON_SPACE else "ollama"
).lower()
# Use 127.0.0.1, not "localhost": on Windows the latter resolves to IPv6 ::1
# first and stalls ~2s per request before falling back to IPv4 (measured — it
# was over half the total latency).
OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://127.0.0.1:11434")
OLLAMA_MODEL = os.environ.get(
"TEMPER_OLLAMA_MODEL", "huihui_ai/gemma-4-abliterated:e4b-q8_0"
)
HF_MODEL = os.environ.get("TEMPER_HF_MODEL", "google/gemma-4-E4B-it")
# The verdict JSON is ~80 tokens; cap generation so the model can't ramble.
# Headroom over that keeps the JSON from ever truncating (which would break
# parsing). Generation is the only length-dependent cost — the image is not.
MAX_NEW_TOKENS = 192
# Our prompt is ~500 tokens; left alone the model loads a 128K context and pays
# the setup cost for it every request. A small context removes most of that.
OLLAMA_NUM_CTX = 4096
# Pin the model in VRAM between requests so there's no reload on the first hit
# after an idle gap. -1 = never unload.
OLLAMA_KEEP_ALIVE = -1
# A single requests.Session shared by every Ollama call, so the HTTP
# connection is reused instead of re-established per request. It is created
# on first use; `requests` is imported only at that point, so importing this
# module does not need it.
_session = None


def _get_session():
    """Return the shared HTTP session, creating it on the first call."""
    global _session
    if _session is not None:
        return _session
    import requests

    _session = requests.Session()
    return _session
def get_backend_name() -> str:
    """Return a display label for the active backend: "<backend> · <model id>"."""
    label, model_id = (
        ("transformers", HF_MODEL)
        if BACKEND == "transformers"
        else ("ollama", OLLAMA_MODEL)
    )
    return f"{label} · {model_id}"
def _to_png_bytes(image: Image.Image) -> bytes:
    """Encode *image* as PNG bytes after converting it to 3-channel RGB."""
    with io.BytesIO() as sink:
        image.convert("RGB").save(sink, format="PNG")
        return sink.getvalue()
# --- Ollama backend ---------------------------------------------------------
def _ollama_error_detail(resp) -> str:
    """Return Ollama's error text from a failed response, or "" if there is none."""
    try:
        body = resp.json()
    except ValueError:  # not JSON (e.g. a proxy's HTML error page)
        return resp.text.strip()
    if isinstance(body, dict) and body.get("error"):
        return str(body["error"])
    return resp.text.strip()


def _score_ollama(image: Image.Image) -> str:
    """Score *image* via the Ollama chat endpoint and return the raw reply text.

    The image is sent as base64-encoded PNG in the user message, with
    SYSTEM_PROMPT as the system message, and Ollama is asked for JSON output.
    The returned string is the assistant message content, still unparsed.

    Raises:
        requests.HTTPError: the server answered with a non-2xx status. If
            Ollama included an ``error`` message in the body (for example an
            unknown model name), it is appended to the exception text.
        requests.RequestException: connection failure or timeout (180 s).
    """
    import requests  # imported lazily, matching _get_session

    b64 = base64.b64encode(_to_png_bytes(image)).decode("ascii")
    payload = {
        "model": OLLAMA_MODEL,
        "format": "json",  # ask Ollama to constrain output to JSON
        "stream": False,
        "keep_alive": OLLAMA_KEEP_ALIVE,
        # NOTE: do NOT set num_predict here — with this model + format:json on
        # Ollama 0.30.7 it returns an empty completion (measured). The JSON
        # format already stops generation when the object closes, so a cap is
        # unnecessary. num_predict/max_new_tokens applies to the HF path only.
        "options": {
            "temperature": 0.4,
            "num_ctx": OLLAMA_NUM_CTX,
        },
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_INSTRUCTION, "images": [b64]},
        ],
    }
    resp = _get_session().post(f"{OLLAMA_HOST}/api/chat", json=payload, timeout=180)
    try:
        resp.raise_for_status()
    except requests.HTTPError as err:
        # raise_for_status() only reports the status line; Ollama puts the
        # actionable reason in the JSON body's "error" field.
        detail = _ollama_error_detail(resp)
        if not detail:
            raise
        raise requests.HTTPError(f"{err} (Ollama: {detail})", response=resp) from err
    return resp.json()["message"]["content"]
# --- Transformers backend (Hugging Face Space / ZeroGPU) --------------------
#
# ZeroGPU rules (https://huggingface.co/docs/hub/spaces-zerogpu):
# * import `spaces` before torch,
# * place the model on `cuda` at MODULE level (a CUDA emulation mode makes this
# work at startup; lazy-loading inside the GPU fn is slower and discouraged),
# * decorate the inference fn with @spaces.GPU (a no-op off ZeroGPU).
# All of this is set up only when the transformers backend is active, so local
# Ollama work needs neither torch nor spaces installed.
def _build_transformers_scorer():
    """Load HF_MODEL onto CUDA and return a ZeroGPU-wrapped scoring function.

    Imports `spaces`, `torch` and `transformers` here rather than at module
    top, so only the transformers path needs them. Loads the processor and
    image-text-to-text model for HF_MODEL, puts the model in eval mode on
    "cuda", and returns `_score`. `_score` maps a PIL image to the model's
    decoded text reply, which is the raw string later handed to
    `parse_verdict`.
    """
    import spaces  # noqa: F401 (import before torch on ZeroGPU)
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    processor = AutoProcessor.from_pretrained(HF_MODEL)
    # dtype="auto": let transformers choose the dtype from the checkpoint.
    model = AutoModelForImageTextToText.from_pretrained(HF_MODEL, dtype="auto")
    model.eval()
    model.to("cuda")  # placed at module level per ZeroGPU; realized in the GPU fn

    # duration: GPU time requested per call, in seconds (ZeroGPU `duration`).
    @spaces.GPU(duration=90)
    def _score(image: Image.Image) -> str:
        """Run one greedy generation for *image* and return only the new text.

        The chat prompt comes from `build_messages`. Generation is capped at
        MAX_NEW_TOKENS, and the prompt tokens are sliced off before decoding.
        """
        messages = build_messages(image.convert("RGB"))
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to("cuda")
        # Prompt length, so that only newly generated tokens are decoded below.
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            # do_sample=False -> greedy decoding.
            out = model.generate(
                **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False
            )
        return processor.decode(out[0][input_len:], skip_special_tokens=True)

    return _score
if BACKEND != "transformers":

    def _score_transformers(image: Image.Image) -> str:
        """Stand-in bound when the transformers backend is off; always raises."""
        raise RuntimeError("transformers backend is not active")

else:
    # Builds the model at import time, as the ZeroGPU setup above requires.
    _score_transformers = _build_transformers_scorer()
# --- Public API -------------------------------------------------------------
def score_image(image: Image.Image) -> TemperVerdict:
    """Score *image* with the backend selected by BACKEND.

    Raises ValueError when *image* is None. Otherwise passes the backend's raw
    text reply to `parse_verdict` and returns its result.
    """
    if image is None:
        raise ValueError("No image provided.")
    backend_fn = _score_transformers if BACKEND == "transformers" else _score_ollama
    raw_reply = backend_fn(image)
    return parse_verdict(raw_reply)