File size: 6,309 Bytes
aa00509
af9809a
aa00509
af9809a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa00509
af9809a
 
 
 
 
 
 
aa00509
af9809a
 
 
 
 
 
 
aa00509
af9809a
aa00509
af9809a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa00509
af9809a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Inference layer for TemperCheck.

Two interchangeable backends, selected by the TEMPER_BACKEND env var (which
defaults to "transformers" on a Hugging Face Space, "ollama" elsewhere):

  - "transformers" (the Hugging Face Space / ZeroGPU) — loads google/gemma-4-E4B-it
                   with transformers and runs inference inside a @spaces.GPU
                   function. This is the deployment path; Gemma 4 vision works
                   here (it is broken in the local Ollama builds — see CLAUDE.md).
  - "ollama"       (local experimentation) — calls a local Ollama server. Fast and
                   torch-free, but the local Gemma 4 vision is unreliable, so this
                   is for plumbing/UI work, not real verdicts.

Everything else in the app talks to `score_image()` and never imports a backend
directly, so the model can be swapped without touching the UI (see CLAUDE.md).
"""

from __future__ import annotations

import base64
import io
import os

from PIL import Image

from .prompt import (
    SYSTEM_PROMPT,
    USER_INSTRUCTION,
    TemperVerdict,
    build_messages,
    parse_verdict,
)

# Backend selection. On a Space, default to the transformers backend; locally,
# default to Ollama. The presence of SPACE_ID is what marks "running on a Space".
def _resolve_backend(raw: str | None, on_space: bool) -> str:
    """Return the backend name for a raw TEMPER_BACKEND value.

    The value is trimmed and lower-cased, so " Transformers " still selects
    the transformers backend instead of silently falling through to Ollama.
    An unset or blank value means "use the environment default":
    "transformers" when ``on_space`` is true, otherwise "ollama".
    """
    choice = (raw or "").strip().lower()
    if choice:
        return choice
    return "transformers" if on_space else "ollama"


_ON_SPACE = bool(os.environ.get("SPACE_ID"))
BACKEND = _resolve_backend(os.environ.get("TEMPER_BACKEND"), _ON_SPACE)

# Use 127.0.0.1, not "localhost": on Windows the latter resolves to IPv6 ::1
# first and stalls ~2s per request before falling back to IPv4 (measured — it
# was over half the total latency).
_DEFAULT_OLLAMA_HOST = "http://127.0.0.1:11434"
_DEFAULT_OLLAMA_PORT = 11434


def _normalize_ollama_host(raw: str | None) -> str:
    """Turn an OLLAMA_HOST value into a base URL that requests can POST to.

    OLLAMA_HOST is shared with the Ollama server/CLI, which accept bare
    ``host[:port]`` values such as ``0.0.0.0:11434``. requests rejects a URL
    without a scheme, and a trailing slash would yield ``//api/chat``. So:

    * unset / blank   -> the 127.0.0.1 default above;
    * explicit scheme -> kept as given (minus trailing slashes);
    * no scheme       -> ``http://`` is added, plus port 11434 when no port
                         is given; a bare ``:PORT`` means 127.0.0.1.
    """
    from urllib.parse import urlsplit

    host = (raw or "").strip().rstrip("/")
    if not host:
        return _DEFAULT_OLLAMA_HOST
    if "://" in host:
        return host
    if host.startswith(":"):
        host = "127.0.0.1" + host
    url = f"http://{host}"
    try:
        parts = urlsplit(url)
        if parts.port is None:
            netloc = f"{parts.netloc}:{_DEFAULT_OLLAMA_PORT}"
            url = parts._replace(netloc=netloc).geturl()
    except ValueError:
        # Unparseable host/port: keep it as-is; requests reports it on first call.
        pass
    return url


OLLAMA_HOST = _normalize_ollama_host(os.environ.get("OLLAMA_HOST"))
OLLAMA_MODEL = os.environ.get(
    "TEMPER_OLLAMA_MODEL", "huihui_ai/gemma-4-abliterated:e4b-q8_0"
)
HF_MODEL = os.environ.get("TEMPER_HF_MODEL", "google/gemma-4-E4B-it")

# Generation cap for the transformers path. The verdict JSON is ~80 tokens;
# 192 stops a rambling model early while leaving enough headroom that the
# object is never cut off mid-way (a truncated object would fail to parse).
# Output length is the only per-request cost this affects — image cost is
# independent of it.
MAX_NEW_TOKENS: int = 192

# Ollama context window. The prompt is ~500 tokens; left at the default the
# model is loaded with a 128K context and pays that setup cost every request.
OLLAMA_NUM_CTX: int = 4096

# Ollama keep_alive: -1 means never unload, so the model stays resident in
# VRAM and the first request after an idle gap doesn't trigger a reload.
OLLAMA_KEEP_ALIVE: int = -1

# One shared requests.Session, created on first use, so consecutive calls can
# reuse the same HTTP connection instead of setting up TCP each time. `requests`
# is imported lazily, so it is only needed once a request is actually made.
_session = None


def _get_session():
    """Return the module-wide requests.Session, creating it on the first call."""
    global _session
    if _session is not None:
        return _session
    import requests

    _session = requests.Session()
    return _session


def get_backend_name() -> str:
    """Return a "backend · model" label describing the active backend.

    Any BACKEND other than "transformers" is reported as Ollama, matching
    how score_image dispatches.
    """
    kind, model_id = (
        ("transformers", HF_MODEL) if BACKEND == "transformers"
        else ("ollama", OLLAMA_MODEL)
    )
    return f"{kind} · {model_id}"


def _to_png_bytes(image: Image.Image) -> bytes:
    """Encode `image` as PNG bytes after converting it to RGB mode."""
    rgb_image = image.convert("RGB")
    with io.BytesIO() as stream:
        rgb_image.save(stream, format="PNG")
        return stream.getvalue()


# --- Ollama backend ---------------------------------------------------------


def _score_ollama(image: Image.Image) -> str:
    """Send `image` to the Ollama chat endpoint and return the raw reply text.

    The image goes as a base64-encoded PNG alongside the system prompt and
    user instruction. The returned string is unparsed model output.

    Raises:
        requests.HTTPError: if Ollama answers with a non-2xx status.
    """
    encoded_png = base64.b64encode(_to_png_bytes(image)).decode("ascii")
    chat_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": USER_INSTRUCTION, "images": [encoded_png]},
    ]
    request_body = {
        "model": OLLAMA_MODEL,
        "format": "json",  # ask Ollama to constrain output to JSON
        "stream": False,
        "keep_alive": OLLAMA_KEEP_ALIVE,
        # NOTE: do NOT set num_predict here — with this model + format:json on
        # Ollama 0.30.7 it returns an empty completion (measured). The JSON
        # format already stops generation when the object closes, so a cap is
        # unnecessary. num_predict/max_new_tokens applies to the HF path only.
        "options": {"temperature": 0.4, "num_ctx": OLLAMA_NUM_CTX},
        "messages": chat_messages,
    }
    session = _get_session()
    response = session.post(f"{OLLAMA_HOST}/api/chat", json=request_body, timeout=180)
    response.raise_for_status()
    reply = response.json()["message"]
    return reply["content"]


# --- Transformers backend (Hugging Face Space / ZeroGPU) --------------------
#
# ZeroGPU rules (https://huggingface.co/docs/hub/spaces-zerogpu):
#   * import `spaces` before torch,
#   * place the model on `cuda` at MODULE level (a CUDA emulation mode makes this
#     work at startup; lazy-loading inside the GPU fn is slower and discouraged),
#   * decorate the inference fn with @spaces.GPU (a no-op off ZeroGPU).
# All of this is set up only when the transformers backend is active, so local
# Ollama work needs neither torch nor spaces installed.


def _build_transformers_scorer():
    """Load the Hugging Face processor + model and return the GPU scoring fn.

    Only called when BACKEND == "transformers" (at import time, below). That
    is why `spaces`, torch and transformers are imported here rather than at
    the top of the file: the Ollama path never needs them.

    Returns:
        A callable ``(PIL.Image.Image) -> str`` that returns the model's raw
        decoded reply. Parsing it into a verdict is left to the caller.
    """
    import spaces  # noqa: F401  (import before torch on ZeroGPU)
    import torch
    from transformers import AutoModelForImageTextToText, AutoProcessor

    processor = AutoProcessor.from_pretrained(HF_MODEL)
    model = AutoModelForImageTextToText.from_pretrained(HF_MODEL, dtype="auto")
    model.eval()
    model.to("cuda")  # placed at module level per ZeroGPU; realized in the GPU fn

    # Each call runs under @spaces.GPU. `duration` is the GPU time requested
    # per call (assumed to be seconds per the ZeroGPU docs — confirm).
    @spaces.GPU(duration=90)
    def _score(image: Image.Image) -> str:
        """Run one generation (no sampling) on `image`; return decoded text."""
        # build_messages (from .prompt) builds the chat structure fed to the
        # processor's chat template — presumably system prompt + image +
        # instruction; see prompt.py.
        messages = build_messages(image.convert("RGB"))
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to("cuda")
        # Prompt length in tokens, so only newly generated tokens are decoded.
        input_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            out = model.generate(
                **inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False
            )
        return processor.decode(out[0][input_len:], skip_special_tokens=True)

    return _score


if BACKEND != "transformers":
    # Placeholder so the name always exists. score_image only routes here
    # when BACKEND == "transformers", so this is never reached in normal use.
    def _score_transformers(image: Image.Image) -> str:
        raise RuntimeError("transformers backend is not active")

else:
    # Loads the model at import time via _build_transformers_scorer().
    _score_transformers = _build_transformers_scorer()


# --- Public API -------------------------------------------------------------


def score_image(image: Image.Image) -> TemperVerdict:
    """Score one PIL image with the configured backend; return the verdict.

    BACKEND == "transformers" uses the Hugging Face model. Every other value
    uses Ollama. The backend's raw text is handed to parse_verdict.

    Raises:
        ValueError: if ``image`` is None.
    """
    if image is None:
        raise ValueError("No image provided.")
    if BACKEND == "transformers":
        raw_text = _score_transformers(image)
    else:
        raw_text = _score_ollama(image)
    return parse_verdict(raw_text)