from __future__ import annotations

import base64
import hashlib
import io
import json
import os
import tempfile
from dataclasses import dataclass
from typing import Any

import gradio as gr
import numpy as np
from PIL import Image


SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"


@dataclass(frozen=True)
class _Embedder:
    processor: Any
    model: Any


_EMBEDDER: _Embedder | None = None


def _get_embedder() -> _Embedder:
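    """Lazily load the SigLIP processor/model once and cache them in-process.

    torch/transformers are imported inside the function so the app can start
    before the model weights are downloaded.
    """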
    global _EMBEDDER
    if _EMBEDDER is not None:
        return _EMBEDDER

    import torch
    from transformers import AutoProcessor, AutoModel

    processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
    model = AutoModel.from_pretrained(SIGLIP_MODEL_ID)
    model.eval()
    torch.set_grad_enabled(False)
    _EMBEDDER = _Embedder(processor=processor, model=model)
    return _EMBEDDER


def _to_pil(x: Any) -> Image.Image:
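    """Accept a PIL image, a Gradio file dict with a "path" key, or a filesystem path."""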
    if isinstance(x, Image.Image):
        return x
    if isinstance(x, dict) and isinstance(x.get("path"), str):
        return Image.open(x["path"]).convert("RGBA")
    if isinstance(x, str):
        return Image.open(x).convert("RGBA")
    raise TypeError(f"Unsupported image input: {type(x).__name__}")


def _sha256_bytes(b: bytes) -> str:
    return hashlib.sha256(b).hexdigest()


def _sha256_image(img: Image.Image) -> str:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return _sha256_bytes(buf.getvalue())


def _l2_normalize(v: np.ndarray) -> np.ndarray:
    n = np.linalg.norm(v, axis=-1, keepdims=True)
    n = np.maximum(n, 1e-12)
    return v / n


def _embed_pils(pils: list[Image.Image]) -> list[dict[str, Any]]:
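    """Embed images with SigLIP and return L2-normalized vectors plus per-image metadata."""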
    import torch

    emb = _get_embedder()
    inputs = emb.processor(images=[p.convert("RGB") for p in pils], return_tensors="pt")
    with torch.no_grad():
        if hasattr(emb.model, "get_image_features"):
            feats = emb.model.get_image_features(**inputs)
        else:
            model_out = emb.model(**inputs)
            # `tensor or fallback` raises an ambiguous-truth-value error, so check for None explicitly.
            feats = getattr(model_out, "pooler_output", None)
            if feats is None:
                feats = model_out.last_hidden_state[:, 0, :]
    feats = feats.detach().cpu().numpy().astype("float32")
    feats = _l2_normalize(feats)
    out: list[dict[str, Any]] = []
    for p, vec in zip(pils, feats):
        out.append(
            {
                "dims": int(vec.shape[0]),
                "norm": "l2",
                "model_id": SIGLIP_MODEL_ID,
                "sha256": _sha256_image(p),
                "vector": vec.tolist(),
            }
        )
    return out


def _dhash(img: Image.Image, size: int = 8) -> str:
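    """Difference hash: compare horizontally adjacent pixels on a small grayscale grid."""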
    g = img.convert("L").resize((size + 1, size), Image.BILINEAR)
    a = np.asarray(g, dtype=np.int16)
    diff = a[:, 1:] > a[:, :-1]
    bits = "".join("1" if x else "0" for x in diff.flatten().tolist())
    return hex(int(bits, 2))[2:].rjust(size * size // 4, "0")


def _laplacian_var(img: Image.Image) -> float:
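    """Variance of a 3x3 Laplacian response; lower values indicate a blurrier image."""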
    g = img.convert("L")
    a = np.asarray(g, dtype=np.float32)
    k = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)

    h, w = a.shape
    if h < 3 or w < 3:
        return 0.0
    # Valid (non-padded) convolution of the Laplacian kernel via shifted slices.
    out = (
        a[1 : h - 1, 0 : w - 2] * k[1, 0]
        + a[0 : h - 2, 1 : w - 1] * k[0, 1]
        + a[1 : h - 1, 1 : w - 1] * k[1, 1]
        + a[2:h, 1 : w - 1] * k[2, 1]
        + a[1 : h - 1, 2:w] * k[1, 2]
    )
    return float(np.var(out))


def image_metrics(image: Any) -> str:
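    """Return a JSON string of basic quality metrics (size, blur, contrast, alpha coverage, hashes)."""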
    img = _to_pil(image)
    arr = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    has_alpha = img.mode in ("RGBA", "LA")
    alpha_cov = 1.0
    if has_alpha:
        a = np.asarray(img.split()[-1], dtype=np.float32) / 255.0
        alpha_cov = float(np.mean(a > 0.05))
    metrics = {
        "width": img.width,
        "height": img.height,
        "blur_laplacian_var": _laplacian_var(img),
        "contrast_std": float(np.std(arr)),
        "mean_brightness": float(np.mean(arr)),
        "dhash": _dhash(img),
        "has_alpha": bool(has_alpha),
        "alpha_coverage": alpha_cov,
        "sha256": _sha256_image(img),
    }
    return json.dumps(metrics)


def _resize_max_side(img: Image.Image, max_side: int) -> Image.Image:
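    """Downscale so the longer side is at most max_side; never upscale."""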
    max_side = int(max_side)
    if max_side <= 0:
        return img
    w, h = img.size
    m = max(w, h)
    if m <= max_side:
        return img
    scale = max_side / float(m)
    nw = max(1, int(round(w * scale)))
    nh = max(1, int(round(h * scale)))
    return img.resize((nw, nh), Image.LANCZOS)


def prepare_for_openai_vlm(image: Any, max_side: int = 768, fmt: str = "webp", quality: int = 85) -> str:
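    """Resize and re-encode an image, returning JSON with a base64 data URL for vision-model requests."""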
    img = _to_pil(image)
    img = _resize_max_side(img, max_side=max_side)
    fmt = (fmt or "webp").lower()
    quality = int(quality)

    buf = io.BytesIO()
    if fmt in ("jpeg", "jpg"):
        mime = "image/jpeg"
        img.convert("RGB").save(buf, format="JPEG", quality=quality, optimize=True)
    elif fmt == "png":
        mime = "image/png"
        img.save(buf, format="PNG", optimize=True)
    else:
        mime = "image/webp"
        img.convert("RGB").save(buf, format="WEBP", quality=quality, method=6)

    b = buf.getvalue()
    url = f"data:{mime};base64," + base64.b64encode(b).decode("ascii")
    out = {
        "url": url,
        "mime": mime,
        "width": img.width,
        "height": img.height,
        "sha256": _sha256_bytes(b),
    }
    return json.dumps(out)


def prepare_for_openai_vlm_batch(images: list[Any], max_side: int = 768, fmt: str = "webp", quality: int = 85) -> str:
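    """Apply prepare_for_openai_vlm to each image and return the results as a JSON array."""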
    out = []
    for x in images or []:
        out.append(json.loads(prepare_for_openai_vlm(x, max_side=max_side, fmt=fmt, quality=quality)))
    return json.dumps(out)


def bg_remove(image: Any) -> tuple[str, str]:
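    """Remove the background with rembg and return (output PNG path, metadata JSON)."""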
    from rembg import remove

    img = _to_pil(image).convert("RGBA")
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    out_bytes = remove(buf.getvalue())

    # Write to a unique temp file so concurrent requests don't overwrite each other's output.
    fd, out_path = tempfile.mkstemp(prefix="bg_removed_", suffix=".png")
    with os.fdopen(fd, "wb") as f:
        f.write(out_bytes)

    meta = {"method": "rembg", "sha256_in": _sha256_image(img), "sha256_out": _sha256_bytes(out_bytes)}
    return out_path, json.dumps(meta)


def trim_alpha(image: Any) -> tuple[str, str]:
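    """Crop away fully transparent borders and return (cropped PNG path, bbox metadata JSON)."""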
    img = _to_pil(image).convert("RGBA")
    a = np.asarray(img.split()[-1], dtype=np.uint8)
    ys, xs = np.where(a > 0)

    # Unique temp path avoids collisions between concurrent requests.
    fd, out_path = tempfile.mkstemp(prefix="trimmed_", suffix=".png")
    os.close(fd)

    if len(xs) == 0 or len(ys) == 0:
        # Fully transparent image: return it unchanged with a full-size bbox.
        img.save(out_path, format="PNG")
        meta = {"bbox": [0, 0, img.width, img.height], "orig_size": [img.width, img.height]}
        return out_path, json.dumps(meta)

    x0, x1 = int(xs.min()), int(xs.max())
    y0, y1 = int(ys.min()), int(ys.max())

    w = x1 - x0 + 1
    h = y1 - y0 + 1
    cropped = img.crop((x0, y0, x0 + w, y0 + h))

    cropped.save(out_path, format="PNG")
    meta = {"bbox": [x0, y0, w, h], "orig_size": [img.width, img.height]}
    return out_path, json.dumps(meta)


def pack_spritesheet(images: list[Any], names_json: str) -> tuple[str, str]:
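    """Pack images into a fixed-cell grid (up to 4 columns) and return (sheet PNG path, placement map JSON)."""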
    try:
        names = json.loads(names_json or "[]")
    except Exception:
        names = []
    if not isinstance(names, list):
        names = []

    pils = [_to_pil(x).convert("RGBA") for x in (images or [])]
    if not pils:
        return "", json.dumps({"error": "no_images"})

    # Uniform cells sized to the largest sprite, laid out row-major.
    cols = min(4, len(pils))
    rows = int(np.ceil(len(pils) / cols))
    cell_w = max(p.width for p in pils)
    cell_h = max(p.height for p in pils)
    sheet = Image.new("RGBA", (cell_w * cols, cell_h * rows), (0, 0, 0, 0))

    mapping: dict[str, Any] = {"cell": [cell_w, cell_h], "items": {}}
    for i, p in enumerate(pils):
        r = i // cols
        c = i % cols
        x = c * cell_w
        y = r * cell_h
        sheet.alpha_composite(p, (x, y))
        key = str(names[i]) if i < len(names) else f"item_{i}"
        mapping["items"][key] = {"x": x, "y": y, "w": p.width, "h": p.height}

    fd, out_path = tempfile.mkstemp(prefix="spritesheet_", suffix=".png")
    os.close(fd)
    sheet.save(out_path, format="PNG")
    return out_path, json.dumps(mapping)


def health() -> str:
    return json.dumps({"ok": True, "embed_dims": 768, "model_id": SIGLIP_MODEL_ID})


def embed_images_batch(images: list[Any]) -> str:
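    """Embed a batch of uploaded images and return the embeddings as a JSON array."""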
    pils = [_to_pil(x) for x in (images or [])]
    out = _embed_pils(pils)
    return json.dumps(out)


with gr.Blocks() as demo:
    gr.Markdown("# Image Processing Service")

    with gr.Tab("API"):
        inp = gr.File(label="Image", file_types=["image"])
        max_side = gr.Slider(128, 2048, value=768, step=64, label="max_side (VLM prep)")
        fmt = gr.Dropdown(["webp", "jpeg", "png"], value="webp", label="format")
        quality = gr.Slider(10, 100, value=85, step=1, label="quality")

        out_json = gr.Code(language="json", label="Output JSON")
        out_file = gr.File(label="Output File")

        gr.Button("Health").click(health, outputs=out_json, api_name="/health")
        gr.Button("Prepare for OpenAI VLM").click(
            prepare_for_openai_vlm,
            inputs=[inp, max_side, fmt, quality],
            outputs=out_json,
            api_name="/prepare_for_openai_vlm",
        )
        gr.Button("Metrics").click(image_metrics, inputs=inp, outputs=out_json, api_name="/image_metrics")
        gr.Button("BG Remove").click(bg_remove, inputs=inp, outputs=[out_file, out_json], api_name="/bg_remove")
        gr.Button("Trim Alpha").click(trim_alpha, inputs=inp, outputs=[out_file, out_json], api_name="/trim_alpha")

        batch_inp = gr.Files(label="Images (batch)", file_types=["image"])
        batch_out = gr.Code(language="json", label="Batch JSON")
        gr.Button("Prepare VLM Batch").click(
            prepare_for_openai_vlm_batch,
            inputs=[batch_inp, max_side, fmt, quality],
            outputs=batch_out,
            api_name="/prepare_for_openai_vlm_batch",
        )
        gr.Button("Embed Batch").click(embed_images_batch, inputs=batch_inp, outputs=batch_out, api_name="/embed_images_batch")

        names = gr.Textbox(label="Names JSON", value='["neutral","happy"]')
        sheet_file = gr.File(label="Spritesheet PNG")
        sheet_map = gr.Code(language="json", label="Spritesheet Map")
        gr.Button("Pack Spritesheet").click(
            pack_spritesheet,
            inputs=[batch_inp, names],
            outputs=[sheet_file, sheet_map],
            api_name="/pack_spritesheet",
        )


if __name__ == "__main__":
    port = int(os.environ.get("PORT", "7860"))
    demo.queue(default_concurrency_limit=2, max_size=64).launch(server_name="0.0.0.0", server_port=port)