|
|
import io |
|
|
import os |
|
|
import shutil |
|
|
import tempfile |
|
|
from typing import Dict, List, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
import tensorflow as tf |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Configuration, read once at import time from environment variables.
# String settings are whitespace-stripped; numeric ones are parsed with int().
# ---------------------------------------------------------------------------

# Where to fetch the model from (checked in this order by _download_model).
MODEL_URL = os.getenv("MODEL_URL", "").strip()
MODEL_REPO = os.getenv("MODEL_REPO", "").strip()
MODEL_REPO_TYPE = os.getenv("MODEL_REPO_TYPE", "model").strip()

# File names inside the repo / working directory.
MODEL_FILE = os.getenv("MODEL_FILE", "raspagem_model_v1.pb").strip()
LABELS_FILE = os.getenv("LABELS_FILE", "labels.txt").strip()

# Side length (pixels) of the square image sent to the model, and how many
# predictions to return.
IMG_SIZE = int(os.getenv("IMG_SIZE", "224"))
TOPK = int(os.getenv("TOPK", "5"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _download_from_url(url: str) -> str:
    """Download ``url`` into a fresh temporary directory and return the local path.

    The response body is streamed to disk in 1 MiB chunks instead of being
    buffered in memory, since model files can be large.

    The local file name is the last component of the URL *path*; the query
    string and fragment are ignored and percent-escapes are decoded. When the
    path has no final component (e.g. it ends in ``/``), ``saved_model.pb`` is
    used instead.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.RequestException: on connection errors or timeouts (60 s).
    On any failure the temporary directory is removed before re-raising.
    """
    import requests
    from urllib.parse import unquote, urlparse

    # basename(url) would keep "?query" / "#fragment" in the file name.
    name = os.path.basename(unquote(urlparse(url).path)) or "saved_model.pb"

    tmp_dir = tempfile.mkdtemp(prefix="raspagem_dl_")
    local = os.path.join(tmp_dir, name)
    try:
        with requests.get(url, timeout=60, stream=True) as resp:
            resp.raise_for_status()
            with open(local, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
    except Exception:
        # Do not leave a half-written download behind.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    return local
|
|
|
|
|
|
|
|
def _download_model() -> str:
    """Resolve the model file to a local path.

    Sources are tried in this order:

    1. ``MODEL_URL``: downloaded over HTTP. Any failure propagates.
    2. ``MODEL_REPO``: fetched with ``hf_hub_download``. Failures are logged
       and the next source is tried.
    3. ``MODEL_FILE`` as a path that already exists locally.

    Returns:
        Local filesystem path of the model.

    Raises:
        FileNotFoundError: if none of the sources yields the model.
    """
    if MODEL_URL:
        return _download_from_url(MODEL_URL)

    if MODEL_REPO:
        # hf_hub_download accepts "model", "dataset" and "space"; any other
        # value falls back to "model".
        repo_type = MODEL_REPO_TYPE.lower()
        if repo_type not in {"model", "dataset", "space"}:
            repo_type = "model"
        try:
            return hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILE,
                repo_type=repo_type,
            )
        except Exception as e:
            # Best effort: fall through to a local copy of MODEL_FILE.
            print(f"[download] HF hub falhou: {e}")

    if os.path.exists(MODEL_FILE):
        return MODEL_FILE

    raise FileNotFoundError(
        "Modelo não encontrado. Defina MODEL_URL OU (MODEL_REPO, MODEL_REPO_TYPE, MODEL_FILE) OU deixe o arquivo na raiz do Space."
    )
|
|
|
|
|
|
|
|
def _prepare_saved_model_dir(pb_path: str) -> str:
    """Return a directory laid out as a TF SavedModel for ``pb_path``.

    ``tf.saved_model.load`` expects a *directory* containing ``saved_model.pb``
    (and a ``variables/`` subdirectory when the graph has variables).

    - If ``pb_path`` is already such a directory, it is returned unchanged.
    - Otherwise a fresh temporary directory is created and ``pb_path`` is
      copied into it as ``saved_model.pb``. If a ``variables/`` directory
      holding ``variables.index`` sits next to ``pb_path``, it is copied too,
      so models with variables can still be loaded.

    Raises:
        OSError: if the copy fails. The temporary directory is removed first.
    """
    if os.path.isdir(pb_path) and os.path.isfile(os.path.join(pb_path, "saved_model.pb")):
        return pb_path

    tmp_dir = tempfile.mkdtemp(prefix="raspagem_savedmodel_")
    try:
        shutil.copy(pb_path, os.path.join(tmp_dir, "saved_model.pb"))
        # Only a genuine checkpoint directory is copied, not any dir named "variables".
        src_vars = os.path.join(os.path.dirname(os.path.abspath(pb_path)), "variables")
        if os.path.isfile(os.path.join(src_vars, "variables.index")):
            shutil.copytree(src_vars, os.path.join(tmp_dir, "variables"))
    except Exception:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    return tmp_dir
|
|
|
|
|
|
|
|
|
|
|
_SERVING_FN = None |
|
|
_LABELS: List[str] = [] |
|
|
_LAST_INIT_ERROR: str | None = None |
|
|
|
|
|
|
|
|
def _maybe_labels() -> List[str]:
    """Load optional class names: one per non-blank line of ``LABELS_FILE``.

    When ``MODEL_REPO`` is set, the file is first fetched from the Hugging Face
    Hub. If that fails, or ``MODEL_REPO`` is unset, ``LABELS_FILE`` is read as
    a local path.

    Errors are logged and never raised: missing labels simply yield ``[]``, and
    callers fall back to ``class_<i>`` names.
    """
    if not LABELS_FILE:
        return []

    candidates: List[str] = []
    if MODEL_REPO:
        # hf_hub_download accepts "model", "dataset" and "space".
        repo_type = MODEL_REPO_TYPE.lower()
        if repo_type not in {"model", "dataset", "space"}:
            repo_type = "model"
        try:
            candidates.append(
                hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=LABELS_FILE,
                    repo_type=repo_type,
                )
            )
        except Exception as e:
            print(f"[labels] ignorando erro: {e}")
    # Local path as the last resort.
    candidates.append(LABELS_FILE)

    for path in candidates:
        try:
            with open(path, "r", encoding="utf-8") as f:
                return [x.strip() for x in f if x.strip()]
        except (OSError, UnicodeDecodeError) as e:
            print(f"[labels] ignorando erro: {e}")
    return []
|
|
|
|
|
|
|
|
def _init_once() -> Tuple[bool, str]:
    """Load the model signature and labels on first use, then reuse them.

    Returns:
        ``(True, "ok")`` once the ``serving_default`` signature is available.
        Otherwise ``(False, "<ExcType>: <message>")``; the same string is
        stored in ``_LAST_INIT_ERROR``.

    Nothing is cached on failure, so the next call retries from scratch.
    Module globals are only assigned after every step has succeeded, so a
    failure never leaves a half-initialised state behind.
    """
    global _SERVING_FN, _LABELS, _LAST_INIT_ERROR, _LOADED_MODEL

    if _SERVING_FN is not None:
        return True, "ok"

    try:
        pb_local = _download_model()
        sm_dir = _prepare_saved_model_dir(pb_local)
        model = tf.saved_model.load(sm_dir)

        serving = model.signatures.get("serving_default")
        if serving is None:
            raise RuntimeError("SavedModel sem assinatura 'serving_default'.")

        labels = _maybe_labels()
    except Exception as e:
        _LAST_INIT_ERROR = f"{type(e).__name__}: {e}"
        return False, _LAST_INIT_ERROR

    # Keep a strong reference to the loaded object for as long as its
    # signature is used: the signature may depend on resources (e.g.
    # variables) owned by that object.
    _LOADED_MODEL = model
    _SERVING_FN = serving
    _LABELS = labels
    _LAST_INIT_ERROR = None
    return True, "ok"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _preprocess_image_to_bytes(pil_img: Image.Image) -> bytes:
    """Encode ``pil_img`` as JPEG bytes of an RGB, IMG_SIZE x IMG_SIZE image.

    The image is converted to RGB and resized to a square; the aspect ratio is
    not preserved. JPEG is written with Pillow's default encoder settings.
    """
    side = IMG_SIZE
    rgb = pil_img.convert("RGB")
    squared = rgb.resize((side, side))
    with io.BytesIO() as buffer:
        squared.save(buffer, format="JPEG")
        return buffer.getvalue()
|
|
|
|
|
|
|
|
def _postprocess(scores: np.ndarray, model_labels: List[str]) -> List[Dict[str, float]]: |
|
|
idxs = np.argsort(scores)[-TOPK:][::-1] |
|
|
out: List[Dict[str, float]] = [] |
|
|
for i in idxs: |
|
|
label = model_labels[i] if i < len(model_labels) and model_labels[i] else ( |
|
|
_LABELS[i] if i < len(_LABELS) else f"class_{i}" |
|
|
) |
|
|
out.append({"index": int(i), "label": label, "score": float(scores[i])}) |
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _signature_info() -> Dict[str, Dict[str, str]]:
    """Describe the serving signature for the UI.

    Returns:
        ``{"inputs": {...}, "outputs": {...}}``, mapping each keyword input
        and each output name to the string form of its spec. If the model
        failed to initialise, returns ``{"init_error": <message>}`` instead.
    """
    ok, err = _init_once()
    if ok:
        keyword_specs = _SERVING_FN.structured_input_signature[1]
        described_inputs = {name: str(spec) for name, spec in keyword_specs.items()}
        described_outputs = {
            name: str(spec) for name, spec in _SERVING_FN.structured_outputs.items()
        }
        return {"inputs": described_inputs, "outputs": described_outputs}
    return {"init_error": err}
|
|
|
|
|
|
|
|
def _diagnostics() -> Dict[str, object]:
    """Report initialisation status and the effective configuration.

    Triggers initialisation through ``_init_once()`` if it has not happened
    yet.

    Returns:
        A dict with these keys:
        - ``ok``: whether the model loaded.
        - ``error``: the init error message, or ``None`` on success.
        - ``num_labels``: how many fallback labels were loaded from
          ``LABELS_FILE``.
        - ``env``: the configuration values read from the environment.
    """
    ok, err = _init_once()
    return {
        "ok": ok,
        "error": err if not ok else None,
        "num_labels": len(_LABELS),
        "env": {
            "MODEL_URL": MODEL_URL or None,
            "MODEL_REPO": MODEL_REPO or None,
            "MODEL_REPO_TYPE": MODEL_REPO_TYPE,
            "MODEL_FILE": MODEL_FILE,
            "LABELS_FILE": LABELS_FILE or None,
            "IMG_SIZE": IMG_SIZE,
            "TOPK": TOPK,
        },
    }
|
|
|
|
|
|
|
|
def infer(image: Image.Image):
    """Classify ``image`` with the SavedModel and return the top-K predictions.

    The image is JPEG-encoded and fed as ``image_bytes``. A constant ``key``
    is also fed, but only when the signature declares it, so signatures
    without that input still work.

    Returns:
        The list produced by ``_postprocess``: one
        ``{"index", "label", "score"}`` dict per prediction.

    Raises:
        ValueError: if no image was provided.
        RuntimeError: if the model could not be initialised.
        KeyError: if the signature produces no ``scores`` output.
    """
    if image is None:
        raise ValueError("Envie uma imagem.")

    ok, err = _init_once()
    if not ok:
        raise RuntimeError(f"Modelo não inicializado: {err}")

    image_bytes = _preprocess_image_to_bytes(image)

    # Feed only the keyword inputs the signature actually declares.
    declared_inputs = _SERVING_FN.structured_input_signature[1]
    feeds = {
        "image_bytes": tf.convert_to_tensor([image_bytes]),
        "key": tf.convert_to_tensor(["0"]),
    }
    result = _SERVING_FN(**{name: t for name, t in feeds.items() if name in declared_inputs})

    scores = result.get("scores")
    if scores is None:
        raise KeyError("Saída 'scores' não encontrada na assinatura do modelo.")
    # Batch of one: take the first (only) row.
    np_scores = scores.numpy()[0]

    # Optional "labels" output: string tensors come back from numpy() as
    # bytes, but tolerate already-decoded values too.
    model_labels: List[str] = []
    labels = result.get("labels")
    if labels is not None:
        model_labels = [
            x.decode("utf-8") if isinstance(x, bytes) else str(x)
            for x in labels.numpy()[0]
        ]

    return _postprocess(np_scores, model_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI. Components are rendered in the order they are created inside the
# `with` contexts below.
# ---------------------------------------------------------------------------
demo = gr.Blocks(title="RaspagemTF - SavedModel (.pb)")

with demo:
    gr.Markdown("## RaspagemTF — Inferência (SavedModel .pb)")

    # Main panel: image input next to the JSON result.
    with gr.Row():
        # type="pil" makes Gradio pass a PIL.Image (or None) to infer().
        img = gr.Image(type="pil", label="Imagem")
        res = gr.JSON(label="Top-K")

    btn = gr.Button("Rodar inferência")
    btn.click(fn=infer, inputs=img, outputs=res)

    # Collapsed panel: init status plus the configuration read from env vars.
    with gr.Accordion("Diagnóstico", open=False):
        d_btn = gr.Button("Rodar diagnóstico")
        d_out = gr.JSON()
        d_btn.click(fn=_diagnostics, inputs=None, outputs=d_out)

    # Collapsed panel: input/output specs of the serving signature.
    with gr.Accordion("Assinaturas do modelo", open=False):
        s_btn = gr.Button("Mostrar assinatura")
        s_out = gr.JSON()
        s_btn.click(fn=_signature_info, inputs=None, outputs=s_out)


if __name__ == "__main__":
    demo.launch()
|
|
|