# Author: chmielvu
# Commit 8fab742 (verified): "Fix runtime: gradio 5 + bind 0.0.0.0"
from __future__ import annotations

import base64
import hashlib
import io
import json
import os
import tempfile
import threading
from dataclasses import dataclass
from typing import Any

import gradio as gr
import numpy as np
from PIL import Image
# ---- Model (SigLIP 768d) ---------------------------------------------------
# Hugging Face Hub id of the SigLIP checkpoint; loaded lazily by _get_embedder()
# and echoed back in embedding records and the /health response.
SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"
@dataclass(frozen=True)
class _Embedder:
    """Immutable bundle of the loaded SigLIP preprocessing object and model.

    Built once by _get_embedder() and cached in the module-level _EMBEDDER.
    """

    # AutoProcessor instance; called with images=... and return_tensors="pt" in _embed_pils.
    processor: Any
    # AutoModel instance, switched to eval() when loaded.
    model: Any
_EMBEDDER: _Embedder | None = None
# Serialises the one-time model load. The app queue runs up to 2 events at once
# (demo.queue(default_concurrency_limit=2)), so without this two first requests
# could each load a full copy of the model.
_EMBEDDER_LOCK = threading.Lock()


def _get_embedder() -> _Embedder:
    """Return the process-wide SigLIP processor/model pair, loading it on first use.

    transformers is imported lazily so that the non-embedding endpoints never
    pay for it. A double-checked lock makes concurrent first callers share a
    single load.

    Returns:
        The cached _Embedder.
    """
    global _EMBEDDER
    # Fast path: no locking once the model is loaded.
    if _EMBEDDER is not None:
        return _EMBEDDER
    with _EMBEDDER_LOCK:
        # Re-check: another thread may have finished loading while we waited.
        if _EMBEDDER is None:
            from transformers import AutoModel, AutoProcessor

            processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
            model = AutoModel.from_pretrained(SIGLIP_MODEL_ID)
            # Inference-only model. Gradients are disabled per call in _embed_pils
            # via torch.no_grad(). torch's grad mode is thread-local, so a global
            # torch.set_grad_enabled(False) here would only affect the loader thread.
            model.eval()
            _EMBEDDER = _Embedder(processor=processor, model=model)
    return _EMBEDDER
def _to_pil(x: Any) -> Image.Image:
    """Coerce a Gradio/API image input into a PIL image.

    Accepted inputs:
      * ``PIL.Image.Image``: returned unchanged (its mode is preserved).
      * ``dict`` with a string ``"path"`` key (Gradio file payloads).
      * ``str`` or ``os.PathLike`` file path.

    File-based inputs are fully decoded and converted to RGBA.

    Raises:
        TypeError: for any other input type.
    """
    if isinstance(x, Image.Image):
        return x
    if isinstance(x, dict) and isinstance(x.get("path"), str):
        x = x["path"]
    if isinstance(x, (str, os.PathLike)):
        # convert() decodes the pixels into a new image. The context manager then
        # closes the file even for multi-frame formats, which Pillow otherwise
        # keeps open after loading.
        with Image.open(x) as src:
            return src.convert("RGBA")
    raise TypeError(f"Unsupported image input: {type(x).__name__}")
def _sha256_bytes(b: bytes) -> str:
return hashlib.sha256(b).hexdigest()
def _sha256_image(img: Image.Image) -> str:
    """Fingerprint *img* by hashing its PNG encoding (default Pillow PNG settings).

    The hash covers the encoded pixels plus mode/size as written to PNG. It is
    not the hash of the original file bytes.
    """
    with io.BytesIO() as png_buf:
        img.save(png_buf, format="PNG")
        encoded = png_buf.getvalue()
    return _sha256_bytes(encoded)
def _l2_normalize(v: np.ndarray) -> np.ndarray:
n = np.linalg.norm(v, axis=-1, keepdims=True)
n = np.maximum(n, 1e-12)
return v / n
def _embed_pils(pils: list[Image.Image]) -> list[dict[str, Any]]:
import torch
emb = _get_embedder()
inputs = emb.processor(images=[p.convert("RGB") for p in pils], return_tensors="pt")
with torch.no_grad():
# SigLIP-style models expose get_image_features on the multi-modal wrapper.
if hasattr(emb.model, "get_image_features"):
feats = emb.model.get_image_features(**inputs)
else:
out = emb.model(**inputs)
feats = getattr(out, "pooler_output", None) or out.last_hidden_state[:, 0, :]
feats = feats.detach().cpu().numpy().astype("float32")
feats = _l2_normalize(feats)
out: list[dict[str, Any]] = []
for p, vec in zip(pils, feats):
out.append(
{
"dims": int(vec.shape[0]),
"norm": "l2",
"model_id": SIGLIP_MODEL_ID,
"sha256": _sha256_image(p),
"vector": vec.tolist(),
}
)
return out
# ---- Metrics / Heuristics ---------------------------------------------------
def _dhash(img: Image.Image, size: int = 8) -> str:
    """Return the difference hash ("dHash") of *img* as a fixed-width hex string.

    The image is reduced to a grayscale thumbnail ``size + 1`` pixels wide and
    ``size`` pixels high (bilinear). Each of the ``size * size`` bits, row-major
    with the first bit most significant, is 1 when a pixel's right-hand
    neighbour is brighter.

    The result always has ``ceil(size * size / 4)`` hex digits (16 for the
    default ``size=8``). Rounding up keeps leading zeros for sizes whose bit
    count is not a multiple of 4.
    """
    gray = img.convert("L").resize((size + 1, size), Image.BILINEAR)
    # int16 keeps the comparison free of uint8 surprises.
    px = np.asarray(gray, dtype=np.int16)
    right_brighter = (px[:, 1:] > px[:, :-1]).flatten()
    value = int("".join("1" if bit else "0" for bit in right_brighter.tolist()), 2)
    hex_digits = (size * size + 3) // 4
    return format(value, f"0{hex_digits}x")
def _laplacian_var(img: Image.Image) -> float:
g = img.convert("L")
a = np.asarray(g, dtype=np.float32)
k = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
# simple conv2d valid
h, w = a.shape
if h < 3 or w < 3:
return 0.0
out = (
a[1 : h - 1, 0 : w - 2] * k[1, 0]
+ a[0 : h - 2, 1 : w - 1] * k[0, 1]
+ a[1 : h - 1, 1 : w - 1] * k[1, 1]
+ a[2:h, 1 : w - 1] * k[2, 1]
+ a[1 : h - 1, 2:w] * k[1, 2]
)
return float(np.var(out))
def image_metrics(image: Any) -> str:
    """Compute cheap quality/identity metrics for one image and return them as JSON.

    Keys:
      * ``width`` / ``height``: pixel size.
      * ``blur_laplacian_var``: Laplacian variance (higher generally means sharper).
      * ``contrast_std`` / ``mean_brightness``: std and mean of the RGB values
        scaled to [0, 1].
      * ``dhash``: difference hash from _dhash (default 8x8 -> 64 bits).
      * ``has_alpha``: True only when the image actually contains a non-opaque pixel.
      * ``alpha_coverage``: fraction of pixels with alpha > 0.05 (1.0 when the
        image has no alpha information).
      * ``sha256``: hash of the image's PNG encoding.

    ``has_alpha`` used to be derived from the mode alone. Because _to_pil
    converts every file input to RGBA, that reported True even for JPEGs.
    """
    img = _to_pil(image)
    arr = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0

    has_alpha = False
    alpha_cov = 1.0
    # Alpha can come from an alpha band or from a palette/colour-key "transparency" entry.
    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        alpha = np.asarray(img.convert("RGBA").getchannel("A"), dtype=np.float32) / 255.0
        alpha_cov = float(np.mean(alpha > 0.05))
        has_alpha = bool(np.any(alpha < 1.0))

    metrics = {
        "width": img.width,
        "height": img.height,
        "blur_laplacian_var": _laplacian_var(img),
        "contrast_std": float(np.std(arr)),
        "mean_brightness": float(np.mean(arr)),
        "dhash": _dhash(img),
        "has_alpha": has_alpha,
        "alpha_coverage": alpha_cov,
        "sha256": _sha256_image(img),
    }
    return json.dumps(metrics)
# ---- VLM prep (OpenAI image_url data URL) ----------------------------------
def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
max_side = int(max_side)
if max_side <= 0:
return img
w, h = img.size
m = max(w, h)
if m <= max_side:
return img
scale = max_side / float(m)
nw = max(1, int(round(w * scale)))
nh = max(1, int(round(h * scale)))
return img.resize((nw, nh), Image.LANCZOS)
def _flatten_to_rgb(img: Image.Image, background: tuple[int, int, int] = (255, 255, 255)) -> Image.Image:
    """Return an RGB version of *img*, compositing any transparency over *background*."""
    if img.mode not in ("RGBA", "LA", "PA") and "transparency" not in img.info:
        return img.convert("RGB")
    rgba = img.convert("RGBA")
    # Fully opaque images need no compositing; a plain conversion is exact.
    if rgba.getchannel("A").getextrema()[0] == 255:
        return rgba.convert("RGB")
    canvas = Image.new("RGBA", rgba.size, background + (255,))
    canvas.alpha_composite(rgba)
    return canvas.convert("RGB")


def prepare_for_openai_vlm(image: Any, max_side: int = 768, fmt: str = "webp", quality: int = 85) -> str:
    """Encode an image as a base64 data URL for an OpenAI ``image_url`` content part.

    Args:
        image: Anything _to_pil accepts.
        max_side: Longest-side limit in pixels; <= 0 disables resizing.
        fmt: ``"webp"`` (default), ``"jpeg"``/``"jpg"`` or ``"png"``
            (case-insensitive). Any other value falls back to WebP.
        quality: Lossy encoder quality, clamped to [0, 100]; unused for PNG.

    Returns:
        JSON object with ``url`` (the data URL), ``mime``, the encoded
        ``width``/``height``, and ``sha256`` of the encoded bytes.

    PNG keeps the alpha channel. JPEG and WebP are written as RGB. Images with
    any non-opaque pixel are first composited over white: simply dropping alpha
    would expose whatever colour happens to be stored under transparent pixels.
    """
    img = _resize_max_side(_to_pil(image), max_side=max_side)
    fmt = (fmt or "webp").lower()
    quality = max(0, min(100, int(quality)))

    buf = io.BytesIO()
    if fmt in ("jpeg", "jpg"):
        mime = "image/jpeg"
        _flatten_to_rgb(img).save(buf, format="JPEG", quality=quality, optimize=True)
    elif fmt == "png":
        mime = "image/png"
        img.save(buf, format="PNG", optimize=True)
    else:
        mime = "image/webp"
        _flatten_to_rgb(img).save(buf, format="WEBP", quality=quality, method=6)

    encoded = buf.getvalue()
    out = {
        "url": f"data:{mime};base64," + base64.b64encode(encoded).decode("ascii"),
        "mime": mime,
        "width": img.width,
        "height": img.height,
        "sha256": _sha256_bytes(encoded),
    }
    return json.dumps(out)
def prepare_for_openai_vlm_batch(images: list[Any], max_side: int = 768, fmt: str = "webp", quality: int = 85) -> str:
    """Run prepare_for_openai_vlm on every image and return the results as one JSON array.

    ``None`` or an empty list yields ``"[]"``. Settings are shared by all images
    and the output order matches the input order.
    """
    records = [
        json.loads(prepare_for_openai_vlm(item, max_side=max_side, fmt=fmt, quality=quality))
        for item in (images or [])
    ]
    return json.dumps(records)
# ---- Background removal + alpha trim ----------------------------------------
def bg_remove(image: Any) -> tuple[str, str]:
    """Remove an image's background with rembg.

    Returns:
        ``(out_path, meta_json)``. ``out_path`` is a freshly created, uniquely
        named ``.png`` file in the system temp directory holding rembg's output
        bytes. ``meta_json`` holds ``method`` plus ``sha256_in`` (hash of the
        input's PNG encoding) and ``sha256_out`` (hash of the output bytes).
    """
    from rembg import remove  # imported lazily: only this endpoint needs it

    img = _to_pil(image).convert("RGBA")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    png_in = buf.getvalue()
    out_bytes = remove(png_in)

    # One file per request. A fixed name in the CWD was shared by concurrent
    # requests (the queue runs 2 at once), so a user could be served someone
    # else's result. Gradio may serve files from tempfile.gettempdir().
    fd, out_path = tempfile.mkstemp(prefix="bg_removed_", suffix=".png")
    with os.fdopen(fd, "wb") as f:
        f.write(out_bytes)

    # png_in is exactly the PNG encoding _sha256_image would produce; reuse it.
    meta = {"method": "rembg", "sha256_in": _sha256_bytes(png_in), "sha256_out": _sha256_bytes(out_bytes)}
    return out_path, json.dumps(meta)
def trim_alpha(image: Any) -> tuple[str, str]:
    """Crop an image to the bounding box of its visible (alpha > 0) pixels.

    Fully transparent images are returned uncropped.

    Returns:
        ``(out_path, meta_json)``. ``out_path`` is a freshly created, uniquely
        named PNG in the system temp directory. ``meta_json`` holds ``bbox`` as
        ``[x, y, w, h]`` within the original image and ``orig_size`` as
        ``[width, height]``.
    """
    img = _to_pil(image).convert("RGBA")
    # getbbox() on the alpha band is the box of non-zero alpha as
    # (left, upper, right, lower) with right/lower exclusive, or None if all zero.
    bbox = img.getchannel("A").getbbox()
    if bbox is None:
        result = img
        meta_bbox = [0, 0, img.width, img.height]
    else:
        left, top, right, bottom = bbox
        result = img.crop(bbox)
        meta_bbox = [left, top, right - left, bottom - top]

    # One file per request. A fixed "trimmed.png" in the CWD was shared by
    # concurrent requests (the queue runs 2 at once), so a user could be served
    # someone else's result. Gradio may serve files from tempfile.gettempdir().
    fd, out_path = tempfile.mkstemp(prefix="trimmed_", suffix=".png")
    with os.fdopen(fd, "wb") as f:
        result.save(f, format="PNG")

    meta = {"bbox": meta_bbox, "orig_size": [img.width, img.height]}
    return out_path, json.dumps(meta)
# ---- Spritesheet packing ----------------------------------------------------
def pack_spritesheet(images: list[Any], names_json: str, max_cols: int = 4) -> tuple[str | None, str]:
    """Pack images into a fixed-grid, transparent RGBA spritesheet.

    Every cell is as large as the largest image, and images are placed at their
    cell's top-left corner, left to right, then top to bottom.

    Args:
        images: Inputs accepted by _to_pil.
        names_json: JSON list whose i-th entry (stringified) keys the i-th image
            in the map. Missing entries become ``item_<i>``. Invalid JSON or a
            non-list is ignored rather than treated as an error.
        max_cols: Maximum number of grid columns (at least 1 is used).

    Returns:
        ``(out_path, map_json)``. ``map_json`` holds ``cell`` as ``[w, h]`` and
        ``items`` as ``{key: {x, y, w, h}}``. Duplicate keys get ``_2``, ``_3``,
        ... suffixes instead of overwriting earlier entries. With no images the
        result is ``(None, '{"error": "no_images"}')``; None is what a
        ``gr.File`` output expects for "no file".
    """
    try:
        names = json.loads(names_json or "[]")
    except (TypeError, ValueError):
        # Malformed names are non-fatal: fall back to auto-generated keys.
        names = []
    if not isinstance(names, list):
        names = []

    pils = [_to_pil(x).convert("RGBA") for x in (images or [])]
    if not pils:
        return None, json.dumps({"error": "no_images"})

    cols = max(1, min(int(max_cols), len(pils)))
    rows = -(-len(pils) // cols)  # ceiling division
    cell_w = max(p.width for p in pils)
    cell_h = max(p.height for p in pils)
    sheet = Image.new("RGBA", (cell_w * cols, cell_h * rows), (0, 0, 0, 0))

    items: dict[str, dict[str, int]] = {}
    mapping: dict[str, Any] = {"cell": [cell_w, cell_h], "items": items}
    for i, p in enumerate(pils):
        row, col = divmod(i, cols)
        x, y = col * cell_w, row * cell_h
        sheet.alpha_composite(p, (x, y))
        base = str(names[i]) if i < len(names) else f"item_{i}"
        key, suffix = base, 2
        while key in items:
            key = f"{base}_{suffix}"
            suffix += 1
        items[key] = {"x": x, "y": y, "w": p.width, "h": p.height}

    # One file per request. A fixed "spritesheet.png" in the CWD was shared by
    # concurrent requests. Gradio may serve files from tempfile.gettempdir().
    fd, out_path = tempfile.mkstemp(prefix="spritesheet_", suffix=".png")
    with os.fdopen(fd, "wb") as f:
        sheet.save(f, format="PNG")
    return out_path, json.dumps(mapping)
# ---- Public endpoints -------------------------------------------------------
def health() -> str:
    """Liveness probe: report static service info as JSON without loading the model."""
    status = {"ok": True, "embed_dims": 768, "model_id": SIGLIP_MODEL_ID}
    return json.dumps(status)
def embed_images_batch(images: list[Any]) -> str:
    """Embed a batch of images with SigLIP and return the per-image records as a JSON array.

    ``None`` is treated as an empty batch; see _embed_pils for the record fields.
    """
    decoded = list(map(_to_pil, images or []))
    return json.dumps(_embed_pils(decoded))
# ---- Gradio UI / API wiring -------------------------------------------------
# Components and .click() handlers are registered in creation order. Each handler
# passes api_name, so it is also published as a named endpoint of the app's API.
with gr.Blocks() as demo:
    gr.Markdown("# Image Processing Service")
    with gr.Tab("API"):
        # Single-image input and encoding controls shared by the per-image endpoints.
        inp = gr.File(label="Image", file_types=["image"])
        max_side = gr.Slider(128, 2048, value=768, step=64, label="max_side (VLM prep)")
        fmt = gr.Dropdown(["webp", "jpeg", "png"], value="webp", label="format")
        quality = gr.Slider(10, 100, value=85, step=1, label="quality")
        # JSON-only endpoints write to out_json; endpoints that produce an image
        # return (path, json) and fill out_file as well.
        out_json = gr.Code(language="json", label="Output JSON")
        out_file = gr.File(label="Output File")
        gr.Button("Health").click(health, outputs=out_json, api_name="/health")
        gr.Button("Prepare for OpenAI VLM").click(
            prepare_for_openai_vlm, inputs=[inp, max_side, fmt, quality], outputs=out_json, api_name="/prepare_for_openai_vlm"
        )
        gr.Button("Metrics").click(image_metrics, inputs=inp, outputs=out_json, api_name="/image_metrics")
        gr.Button("BG Remove").click(bg_remove, inputs=inp, outputs=[out_file, out_json], api_name="/bg_remove")
        gr.Button("Trim Alpha").click(trim_alpha, inputs=inp, outputs=[out_file, out_json], api_name="/trim_alpha")
        # Batch endpoints (API-only; UI is minimal)
        batch_inp = gr.Files(label="Images (batch)", file_types=["image"])
        batch_out = gr.Code(language="json", label="Batch JSON")
        # The batch VLM endpoint reuses the single-image max_side/format/quality controls.
        gr.Button("Prepare VLM Batch").click(
            prepare_for_openai_vlm_batch, inputs=[batch_inp, max_side, fmt, quality], outputs=batch_out, api_name="/prepare_for_openai_vlm_batch"
        )
        gr.Button("Embed Batch").click(embed_images_batch, inputs=batch_inp, outputs=batch_out, api_name="/embed_images_batch")
        # Spritesheet pack: packs the batch uploads; the i-th entry of the Names
        # JSON list becomes the map key of the i-th image.
        names = gr.Textbox(label="Names JSON", value='["neutral","happy"]')
        sheet_file = gr.File(label="Spritesheet PNG")
        sheet_map = gr.Code(language="json", label="Spritesheet Map")
        gr.Button("Pack Spritesheet").click(pack_spritesheet, inputs=[batch_inp, names], outputs=[sheet_file, sheet_map], api_name="/pack_spritesheet")
if __name__ == "__main__":
    # HF Spaces sits behind a proxy: listen on every interface and honour the
    # platform-provided PORT, falling back to 7860.
    server_port = int(os.getenv("PORT", "7860"))
    # At most 2 events run concurrently per handler; up to 64 requests may wait.
    queued_app = demo.queue(default_concurrency_limit=2, max_size=64)
    queued_app.launch(server_name="0.0.0.0", server_port=server_port)