"""Inference layer for TemperCheck.

Two interchangeable backends, selected by the TEMPER_BACKEND env var (which
defaults to "transformers" on a Hugging Face Space, "ollama" elsewhere):

  - "transformers" (the Hugging Face Space / ZeroGPU) — loads google/gemma-4-E4B-it
    with transformers and runs inference inside a @spaces.GPU function. This is
    the deployment path; Gemma 4 vision works here (it is broken in the local
    Ollama builds — see CLAUDE.md).
  - "ollama" (local experimentation) — calls a local Ollama server. Fast and
    torch-free, but the local Gemma 4 vision is unreliable, so this is for
    plumbing/UI work, not real verdicts.

Everything else in the app talks to `score_image()` and never imports a backend
directly, so the model can be swapped without touching the UI (see CLAUDE.md).
"""
from __future__ import annotations
import base64
import io
import os
from PIL import Image
from .prompt import (
SYSTEM_PROMPT,
USER_INSTRUCTION,
TemperVerdict,
build_messages,
parse_verdict,
)
# A non-empty SPACE_ID is treated as "running on a Space". There the default
# backend is transformers; everywhere else it is Ollama. TEMPER_BACKEND
# overrides the default either way and is lower-cased before use.
_ON_SPACE = bool(os.getenv("SPACE_ID"))
if _ON_SPACE:
    BACKEND = os.getenv("TEMPER_BACKEND", "transformers").lower()
else:
    BACKEND = os.getenv("TEMPER_BACKEND", "ollama").lower()

# Deliberately 127.0.0.1 rather than "localhost": on Windows "localhost"
# resolves to IPv6 ::1 first and stalls ~2s per request before falling back to
# IPv4 (measured — it was over half the total latency).
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://127.0.0.1:11434")
OLLAMA_MODEL = os.getenv(
    "TEMPER_OLLAMA_MODEL", "huihui_ai/gemma-4-abliterated:e4b-q8_0"
)
HF_MODEL = os.getenv("TEMPER_HF_MODEL", "google/gemma-4-E4B-it")

# Generation cap. The verdict JSON is ~80 tokens, so this leaves headroom that
# keeps the JSON from ever truncating (which would break parsing) while still
# stopping the model from rambling. Generation is the only length-dependent
# cost — the image is not.
MAX_NEW_TOKENS = 192

# Context window requested from Ollama. Our prompt is ~500 tokens; left alone
# the model loads a 128K context and pays the setup cost for it every request.
# A small context removes most of that.
OLLAMA_NUM_CTX = 4096

# -1 = never unload: the model stays pinned in VRAM between requests, so there
# is no reload on the first hit after an idle gap.
OLLAMA_KEEP_ALIVE = -1

# Shared HTTP session, created on first use, so one connection is reused across
# requests (cheap; avoids per-call TCP setup).
_session = None
def _get_session():
    """Return the shared ``requests.Session``, creating it on the first call.

    The session is cached in the module-level ``_session`` so every later call
    hands back the same object. ``requests`` is imported inside the function,
    so importing this module does not by itself require ``requests``.
    """
    global _session
    if _session is not None:
        return _session

    import requests

    _session = requests.Session()
    return _session
def get_backend_name() -> str:
    """Return a display label of the form ``"<backend> · <model id>"``.

    Any ``BACKEND`` value other than ``"transformers"`` is labelled as Ollama.
    """
    return (
        f"transformers · {HF_MODEL}"
        if BACKEND == "transformers"
        else f"ollama · {OLLAMA_MODEL}"
    )
def _to_png_bytes(image: Image.Image) -> bytes:
    """Return *image* encoded as PNG bytes.

    The image is converted to RGB mode before it is saved.
    """
    rgb_image = image.convert("RGB")
    with io.BytesIO() as sink:
        rgb_image.save(sink, format="PNG")
        return sink.getvalue()
# --- Ollama backend ---------------------------------------------------------
def _score_ollama(image: Image.Image) -> str:
    """Ask the Ollama server about *image* and return the raw reply text.

    The image is PNG-encoded and base64'd, then posted to ``/api/chat`` on
    ``OLLAMA_HOST`` together with the system prompt and user instruction. The
    request is non-streaming and uses the shared HTTP session with a timeout
    of 180.

    Returns:
        The ``message.content`` string from the JSON response (unparsed).

    Raises:
        Whatever ``raise_for_status()`` raises for an HTTP error status, and
        ``KeyError`` if the response JSON lacks ``message``/``content``.
    """
    encoded_image = base64.b64encode(_to_png_bytes(image)).decode("ascii")

    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": USER_INSTRUCTION, "images": [encoded_image]},
    ]

    # NOTE: do NOT add num_predict to these options — with this model +
    # format:json on Ollama 0.30.7 it returns an empty completion (measured).
    # The JSON format already stops generation when the object closes, so a cap
    # is unnecessary. num_predict/max_new_tokens applies to the HF path only.
    options = {
        "temperature": 0.4,
        "num_ctx": OLLAMA_NUM_CTX,
    }

    body = {
        "model": OLLAMA_MODEL,
        "format": "json",  # ask Ollama to constrain output to JSON
        "stream": False,
        "keep_alive": OLLAMA_KEEP_ALIVE,
        "options": options,
        "messages": chat,
    }

    reply = _get_session().post(f"{OLLAMA_HOST}/api/chat", json=body, timeout=180)
    reply.raise_for_status()
    return reply.json()["message"]["content"]
# --- Transformers backend (Hugging Face Space / ZeroGPU) --------------------
#
# ZeroGPU rules (https://huggingface.co/docs/hub/spaces-zerogpu):
# * import `spaces` before torch,
# * place the model on `cuda` at MODULE level (a CUDA emulation mode makes this
# work at startup; lazy-loading inside the GPU fn is slower and discouraged),
# * decorate the inference fn with @spaces.GPU (a no-op off ZeroGPU).
# All of this is set up only when the transformers backend is active, so local
# Ollama work needs neither torch nor spaces installed.
def _build_transformers_scorer():
    """Load the HF processor and model once and return the scoring function.

    All heavy imports (``spaces``, ``torch``, ``transformers``) happen here, so
    they are needed only when this builder is actually called. The model is
    loaded from ``HF_MODEL``, switched to eval mode and moved to ``cuda``
    before the scorer is defined.

    Returns:
        A function ``_score(image) -> str`` decorated with
        ``@spaces.GPU(duration=90)`` that returns the decoded text the model
        generated for *image*, without the prompt tokens.
    """
    import spaces  # noqa: F401 (import before torch on ZeroGPU)
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    proc = AutoProcessor.from_pretrained(HF_MODEL)
    net = AutoModelForImageTextToText.from_pretrained(HF_MODEL, dtype="auto")
    net.eval()
    net.to("cuda")  # placed at module level per ZeroGPU; realized in the GPU fn

    # Greedy decoding with a hard cap on the number of generated tokens.
    generation_kwargs = {"max_new_tokens": MAX_NEW_TOKENS, "do_sample": False}

    @spaces.GPU(duration=90)
    def _score(image: Image.Image) -> str:
        """Run one generation pass for *image* and return the new text only."""
        chat = build_messages(image.convert("RGB"))
        batch = proc.apply_chat_template(
            chat,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        )
        batch = batch.to("cuda")

        # Length of the prompt in tokens, used below to slice the prompt off
        # the front of the generated sequence.
        prompt_len = batch["input_ids"].shape[-1]

        with torch.inference_mode():
            generated = net.generate(**batch, **generation_kwargs)

        completion_ids = generated[0][prompt_len:]
        return proc.decode(completion_ids, skip_special_tokens=True)

    return _score
# The real scorer (and with it the model load) is built only when the
# transformers backend is selected; otherwise a stub that raises is bound to
# the same name.
if BACKEND != "transformers":

    def _score_transformers(image: Image.Image) -> str:
        """Stub used when the transformers backend is off; always raises."""
        raise RuntimeError("transformers backend is not active")

else:
    _score_transformers = _build_transformers_scorer()
# --- Public API -------------------------------------------------------------
def score_image(image: Image.Image) -> TemperVerdict:
    """Score a PIL image with the configured backend and return the verdict.

    The image goes to the transformers scorer when ``BACKEND`` is
    ``"transformers"`` and to the Ollama scorer for any other value. The
    backend's raw text reply is then passed through ``parse_verdict``.

    Raises:
        ValueError: if *image* is ``None``.
    """
    if image is None:
        raise ValueError("No image provided.")

    if BACKEND == "transformers":
        raw_reply = _score_transformers(image)
    else:
        raw_reply = _score_ollama(image)

    return parse_verdict(raw_reply)
|