# raspagem_supra / app.py
# Author: vcollos — "original v2" (commit 40114f7, verified)
import io
import os
import shutil
import tempfile
from typing import Dict, List, Tuple
import numpy as np
import gradio as gr
from PIL import Image
import tensorflow as tf
from huggingface_hub import hf_hub_download
import spaces
# =========================
# MODEL CONFIGURATION (all overridable via environment variables)
# =========================
# Direct HTTP(S) URL to the V2 .pb file; when set, takes precedence over MODEL_REPO.
MODEL_URL = os.environ.get("MODEL_URL", "").strip()
# Hugging Face Hub repo id to download the model/labels from (e.g. "user/repo").
MODEL_REPO = os.environ.get("MODEL_REPO", "").strip()
# Hub repo type; anything other than "model"/"space" is coerced to "model" at use sites.
MODEL_REPO_TYPE = os.environ.get("MODEL_REPO_TYPE", "model").strip()
# Filename of the V2 ("before/after need") model inside the repo or local dir.
MODEL_V2_FILE = os.environ.get("MODEL_FILE", "raspagem_2025_antes_depois.pb").strip()
# Local path to the V1 ("sextant") model file; empty means not configured.
MODEL_V1_FILE = os.environ.get("MODELO_V1", "").strip()
# Optional labels file, one label per line; used as fallback when the model
# output does not carry its own "labels" tensor.
LABELS_FILE = os.environ.get("LABELS_FILE", "labels.txt").strip()
# Square side the input image is resized to before JPEG encoding.
IMG_SIZE = int(os.environ.get("IMG_SIZE", "224"))
# NOTE(review): TOPK is read here but never referenced elsewhere in this file.
TOPK = int(os.environ.get("TOPK", "0"))
# =========================
# Lazy state — populated on first inference, cached for later calls
# =========================
_SERVING_V1 = None  # serving_default ConcreteFunction of the V1 model, once loaded
_SERVING_V2 = None  # serving_default ConcreteFunction of the V2 model, once loaded
_LABELS_V1: List[str] = []  # fallback label list for V1
_LABELS_V2: List[str] = []  # fallback label list for V2
_LAST_INIT_ERROR_V1 = None  # last V1 load failure ("Type: message"), for diagnostics
_LAST_INIT_ERROR_V2 = None  # last V2 load failure, for diagnostics
# =========================
# Utilitários
# =========================
def _download_from_url(url: str) -> str:
    """Download a model file over HTTP into a fresh temp dir; return the local path.

    Streams the response in 1 MiB chunks so large model files are never held
    fully in memory (the original buffered the whole body via ``resp.content``).

    Raises:
        requests.HTTPError: on a non-2xx status code.
    """
    import requests  # local import: only needed when MODEL_URL is configured

    tmp_dir = tempfile.mkdtemp(prefix="raspagem_dl_")
    # Drop any query string so the derived local filename stays clean.
    base_name = os.path.basename(url.split("?", 1)[0]) or "saved_model.pb"
    local = os.path.join(tmp_dir, base_name)
    with requests.get(url, timeout=60, stream=True) as resp:
        resp.raise_for_status()
        with open(local, "wb") as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)
    return local
def _prepare_saved_model_dir(pb_path: str) -> str:
tmp_dir = tempfile.mkdtemp(prefix="raspagem_savedmodel_")
shutil.copy(pb_path, os.path.join(tmp_dir, "saved_model.pb"))
return tmp_dir
def _load_model_from_file(pb_file: str) -> tf.types.experimental.ConcreteFunction:
    """Load a SavedModel from a bare .pb file and return its serving signature.

    Raises:
        RuntimeError: when the model exposes no 'serving_default' signature.
    """
    staged_dir = _prepare_saved_model_dir(pb_file)
    loaded = tf.saved_model.load(staged_dir)
    signature = loaded.signatures.get("serving_default")
    if signature is not None:
        return signature
    raise RuntimeError(f"Modelo {pb_file} sem assinatura 'serving_default'.")
def _maybe_labels() -> List[str]:
    """Best-effort load of class labels; return [] when unavailable.

    When MODEL_REPO is set the labels file is fetched from the HF Hub,
    otherwise LABELS_FILE is read as a local path. Any failure is logged
    and swallowed so inference can still proceed without fallback labels.
    """
    if not LABELS_FILE:
        return []
    try:
        if MODEL_REPO:
            repo_type = MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model"
            labels_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=LABELS_FILE,
                repo_type=repo_type,
            )
        else:
            labels_path = LABELS_FILE
        with open(labels_path, "r", encoding="utf-8") as fh:
            return [line.strip() for line in fh if line.strip()]
    except Exception as exc:
        print(f"[labels] ignorando erro: {exc}")
        return []
# =========================
# Inicialização
# =========================
def _init_v1() -> Tuple[bool, str]:
    """Lazily load the V1 (sextant) model; return (ok, detail).

    On success caches the serving function and labels in module state and
    returns (True, "ok"); on failure records the error message and returns
    (False, message). Subsequent calls are no-ops once loaded.
    """
    global _SERVING_V1, _LABELS_V1, _LAST_INIT_ERROR_V1
    if _SERVING_V1 is not None:
        return True, "ok"
    try:
        if not os.path.exists(MODEL_V1_FILE):
            raise FileNotFoundError(f"MODELO_V1 não encontrado: {MODEL_V1_FILE}")
        _SERVING_V1 = _load_model_from_file(MODEL_V1_FILE)
        _LABELS_V1 = _maybe_labels()
    except Exception as exc:
        _LAST_INIT_ERROR_V1 = f"{type(exc).__name__}: {exc}"
        return False, _LAST_INIT_ERROR_V1
    return True, "ok"
def _init_v2() -> Tuple[bool, str]:
    """Lazily load the V2 (scaling-need) model; return (ok, detail).

    Model source precedence: MODEL_URL download > HF Hub (MODEL_REPO) >
    local MODEL_V2_FILE. On failure records the error message and returns
    (False, message); once loaded, later calls are no-ops.
    """
    global _SERVING_V2, _LABELS_V2, _LAST_INIT_ERROR_V2
    if _SERVING_V2 is not None:
        return True, "ok"
    try:
        if MODEL_URL:
            pb_path = _download_from_url(MODEL_URL)
        elif MODEL_REPO:
            repo_type = MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model"
            pb_path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_V2_FILE,
                repo_type=repo_type,
            )
        elif os.path.exists(MODEL_V2_FILE):
            pb_path = MODEL_V2_FILE
        else:
            raise FileNotFoundError("MODEL_FILE não encontrado")
        _SERVING_V2 = _load_model_from_file(pb_path)
        _LABELS_V2 = _maybe_labels()
    except Exception as exc:
        _LAST_INIT_ERROR_V2 = f"{type(exc).__name__}: {exc}"
        return False, _LAST_INIT_ERROR_V2
    return True, "ok"
# =========================
# Processamento
# =========================
def _preprocess_image_to_bytes(pil_img: Image.Image) -> bytes:
    """Resize to IMG_SIZE x IMG_SIZE RGB and return the JPEG-encoded bytes."""
    canvas = pil_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    with io.BytesIO() as buffer:
        canvas.save(buffer, format="JPEG")
        return buffer.getvalue()
def _pretty_label(raw: str) -> str:
s = (raw or "").strip().lower()
m = {
"necessario": "Necessário",
"necessário": "Necessário",
"nao_necessario": "Não necessário",
"não_necessário": "Não necessário",
"s1": "S1", "s2": "S2", "s3": "S3", "s4": "S4", "s5": "S5", "s6": "S6"
}
key = s.replace(" ", "").replace("ã", "a").replace("á", "a").replace("é", "e").replace("í", "i").replace("ó", "o").replace("ç", "c")
return m.get(key, raw.strip().capitalize())
def _format_result(label: str, score: float, tipo: str) -> str:
    """Render one prediction as ``'Tipo: Label (xx.x%)'``."""
    display = _pretty_label(label)
    return f"{tipo}: {display} ({score:.1%})"
# =========================
# Inferência combinada
# =========================
def _run_signature(serving, key: str, fallback_labels: List[str], tipo: str,
                   image_bytes: bytes) -> str:
    """Run one serving signature on the encoded image and format the top-1 result.

    Prefers the label names returned by the model itself ("labels" output);
    falls back to the labels loaded from LABELS_FILE otherwise.
    """
    res = serving(
        image_bytes=tf.convert_to_tensor([image_bytes]),
        key=tf.convert_to_tensor([key]),
    )
    scores = res["scores"].numpy()[0]
    if "labels" in res:
        labels = [x.decode("utf-8") for x in res["labels"].numpy()[0]]
    else:
        labels = fallback_labels
    best = int(np.argmax(scores))
    return _format_result(labels[best], scores[best], tipo)

@spaces.GPU(duration=120)
def infer(image: Image.Image):
    """Run both models on one image and return their top-1 results as text.

    V1 predicts the sextant, V2 the scraping need; each model is lazily
    loaded on first use. Raises ValueError on missing input and
    RuntimeError when either model fails to load.
    """
    if image is None:
        raise ValueError("Envie uma imagem.")
    image_bytes = _preprocess_image_to_bytes(image)

    # V1 - Sextante
    ok1, err1 = _init_v1()
    if not ok1:
        raise RuntimeError(f"Erro ao carregar modelo V1: {err1}")
    sextante = _run_signature(_SERVING_V1, "v1", _LABELS_V1, "Sextante", image_bytes)

    # V2 - Necessidade
    ok2, err2 = _init_v2()
    if not ok2:
        raise RuntimeError(f"Erro ao carregar modelo V2: {err2}")
    necessidade = _run_signature(_SERVING_V2, "v2", _LABELS_V2, "Necessidade", image_bytes)

    return f"{sextante}\n{necessidade}"
# =========================
# Gradio UI
# =========================
# Build the Gradio UI: one image input, one multi-line text output.
with gr.Blocks(title="RaspagemTF - V1 + V2") as demo:
    gr.Markdown("## RaspagemTF — Inferência em dois modelos (Sextante, Necessidade)")
    with gr.Row():
        img = gr.Image(type="pil", label="Imagem")
        res = gr.Textbox(label="Resultados", lines=4)
    btn = gr.Button("Rodar inferência")
    btn.click(fn=infer, inputs=img, outputs=res)
if __name__ == "__main__":
    # Enable request queuing before starting the Gradio server.
    demo.queue()
    demo.launch()