Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import shutil | |
| import tempfile | |
| from typing import Dict, List, Tuple | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| import tensorflow as tf | |
| from huggingface_hub import hf_hub_download | |
| import spaces | |
| # ========================= | |
| # VARIÁVEIS DE MODELO | |
| # ========================= | |
| MODEL_URL = os.environ.get("MODEL_URL", "").strip() | |
| MODEL_REPO = os.environ.get("MODEL_REPO", "").strip() | |
| MODEL_REPO_TYPE = os.environ.get("MODEL_REPO_TYPE", "model").strip() | |
| MODEL_V2_FILE = os.environ.get("MODEL_FILE", "raspagem_2025_antes_depois.pb").strip() | |
| MODEL_V1_FILE = os.environ.get("MODELO_V1", "").strip() | |
| LABELS_FILE = os.environ.get("LABELS_FILE", "labels.txt").strip() | |
| IMG_SIZE = int(os.environ.get("IMG_SIZE", "224")) | |
| TOPK = int(os.environ.get("TOPK", "0")) | |
| # ========================= | |
| # Lazy state | |
| # ========================= | |
# Serving signatures, populated on first use by _init_v1 / _init_v2.
_SERVING_V1 = _SERVING_V2 = None
# Last initialization failure per model, formatted as "ExceptionType: message".
_LAST_INIT_ERROR_V1 = _LAST_INIT_ERROR_V2 = None
# Fallback label lists; kept as two distinct list objects on purpose.
_LABELS_V1: List[str] = []
_LABELS_V2: List[str] = []
| # ========================= | |
| # Utilitários | |
| # ========================= | |
def _download_from_url(url: str) -> str:
    """Download ``url`` into a fresh temporary directory and return the file path.

    The local file name is the last component of the URL *path* (query string
    and fragment are ignored, percent-escapes are decoded); it falls back to
    ``saved_model.pb`` when the path has no usable final component.

    The body is streamed to disk in chunks so a large model file is never held
    entirely in memory, and the temporary directory is removed again if the
    download fails.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
        OSError: if the file cannot be written.
    """
    import requests
    from urllib.parse import unquote, urlparse

    name = os.path.basename(unquote(urlparse(url).path))
    if name in ("", ".", ".."):
        name = "saved_model.pb"

    tmp_dir = tempfile.mkdtemp(prefix="raspagem_dl_")
    local = os.path.join(tmp_dir, name)
    try:
        with requests.get(url, timeout=60, stream=True) as resp:
            resp.raise_for_status()
            with open(local, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
    except Exception:
        # Do not leave a half-written download behind.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise
    return local
| def _prepare_saved_model_dir(pb_path: str) -> str: | |
| tmp_dir = tempfile.mkdtemp(prefix="raspagem_savedmodel_") | |
| shutil.copy(pb_path, os.path.join(tmp_dir, "saved_model.pb")) | |
| return tmp_dir | |
def _load_model_from_file(pb_file: str) -> tf.types.experimental.ConcreteFunction:
    """Load ``pb_file`` as a SavedModel and return its ``serving_default`` signature.

    The file is first staged into a temporary directory by
    ``_prepare_saved_model_dir`` and then loaded with ``tf.saved_model.load``.

    Raises:
        RuntimeError: if the loaded model has no ``serving_default`` signature.
    """
    loaded = tf.saved_model.load(_prepare_saved_model_dir(pb_file))
    signature = loaded.signatures.get("serving_default")
    if signature is not None:
        return signature
    raise RuntimeError(f"Modelo {pb_file} sem assinatura 'serving_default'.")
def _maybe_labels() -> List[str]:
    """Best-effort load of the labels file; returns ``[]`` when unavailable.

    When ``MODEL_REPO`` is set the file is fetched from the Hub, otherwise
    ``LABELS_FILE`` is opened as a local path. Blank lines are dropped and
    each label is stripped. Any error is printed and swallowed on purpose:
    labels are optional.
    """
    if not LABELS_FILE:
        return []
    try:
        if MODEL_REPO:
            repo_kind = MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model"
            path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=LABELS_FILE,
                repo_type=repo_kind,
            )
        else:
            path = LABELS_FILE
        with open(path, "r", encoding="utf-8") as handle:
            stripped = (line.strip() for line in handle)
            return [line for line in stripped if line]
    except Exception as e:
        print(f"[labels] ignorando erro: {e}")
    return []
| # ========================= | |
| # Inicialização | |
| # ========================= | |
def _init_v1() -> Tuple[bool, str]:
    """Lazily load the V1 model and its fallback labels.

    Returns ``(True, "ok")`` once the serving signature is available, or
    ``(False, "<ExceptionType>: <message>")`` when loading fails; the failure
    text is also stored in ``_LAST_INIT_ERROR_V1``.
    """
    global _SERVING_V1, _LABELS_V1, _LAST_INIT_ERROR_V1
    if _SERVING_V1 is None:
        try:
            # V1 is only ever read from a local path.
            if not os.path.exists(MODEL_V1_FILE):
                raise FileNotFoundError(f"MODELO_V1 não encontrado: {MODEL_V1_FILE}")
            _SERVING_V1 = _load_model_from_file(MODEL_V1_FILE)
            _LABELS_V1 = _maybe_labels()
        except Exception as e:
            _LAST_INIT_ERROR_V1 = f"{type(e).__name__}: {e}"
            return False, _LAST_INIT_ERROR_V1
    return True, "ok"
def _resolve_v2_model_path() -> str:
    """Locate the V2 model file: direct URL first, then Hub repo, then local path."""
    if MODEL_URL:
        return _download_from_url(MODEL_URL)
    if MODEL_REPO:
        repo_kind = MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model"
        return hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_V2_FILE,
            repo_type=repo_kind,
        )
    if os.path.exists(MODEL_V2_FILE):
        return MODEL_V2_FILE
    raise FileNotFoundError("MODEL_FILE não encontrado")


def _init_v2() -> Tuple[bool, str]:
    """Lazily load the V2 model and its fallback labels.

    Returns ``(True, "ok")`` once the serving signature is available, or
    ``(False, "<ExceptionType>: <message>")`` when locating or loading the
    model fails; the failure text is also stored in ``_LAST_INIT_ERROR_V2``.
    """
    global _SERVING_V2, _LABELS_V2, _LAST_INIT_ERROR_V2
    if _SERVING_V2 is None:
        try:
            _SERVING_V2 = _load_model_from_file(_resolve_v2_model_path())
            _LABELS_V2 = _maybe_labels()
        except Exception as e:
            _LAST_INIT_ERROR_V2 = f"{type(e).__name__}: {e}"
            return False, _LAST_INIT_ERROR_V2
    return True, "ok"
| # ========================= | |
| # Processamento | |
| # ========================= | |
def _preprocess_image_to_bytes(pil_img: Image.Image) -> bytes:
    """Return ``pil_img`` as JPEG bytes after RGB conversion and resizing.

    The image is resized to ``IMG_SIZE`` x ``IMG_SIZE`` and encoded in memory.
    """
    rgb = pil_img.convert("RGB")
    resized = rgb.resize((IMG_SIZE, IMG_SIZE))
    with io.BytesIO() as buffer:
        resized.save(buffer, format="JPEG")
        return buffer.getvalue()
| def _pretty_label(raw: str) -> str: | |
| s = (raw or "").strip().lower() | |
| m = { | |
| "necessario": "Necessário", | |
| "necessário": "Necessário", | |
| "nao_necessario": "Não necessário", | |
| "não_necessário": "Não necessário", | |
| "s1": "S1", "s2": "S2", "s3": "S3", "s4": "S4", "s5": "S5", "s6": "S6" | |
| } | |
| key = s.replace(" ", "").replace("ã", "a").replace("á", "a").replace("é", "e").replace("í", "i").replace("ó", "o").replace("ç", "c") | |
| return m.get(key, raw.strip().capitalize()) | |
def _format_result(label: str, score: float, tipo: str) -> str:
    """Format one prediction as ``"<tipo>: <pretty label> (<score as %>)"``.

    ``score`` is multiplied by 100 and shown with one decimal place.
    """
    display = _pretty_label(label)
    percent = score * 100
    return f"{tipo}: {display} ({percent:.1f}%)"
| # ========================= | |
| # Inferência combinada | |
| # ========================= | |
def _top_prediction(serving, image_bytes: bytes, key: str, fallback_labels: List[str]):
    """Run one serving signature and return its best ``(label, score)`` pair.

    The signature is called with a one-element batch (``image_bytes`` and
    ``key``). Labels come from the model's ``labels`` output when present,
    otherwise from ``fallback_labels``. If no label exists for the winning
    index, a generic ``"classe <index>"`` label is used instead of raising
    ``IndexError``.
    """
    outputs = serving(
        image_bytes=tf.convert_to_tensor([image_bytes]),
        key=tf.convert_to_tensor([key]),
    )
    scores = outputs["scores"].numpy()[0]
    if "labels" in outputs:
        labels = [raw.decode("utf-8") for raw in outputs["labels"].numpy()[0]]
    else:
        labels = fallback_labels
    best = int(np.argmax(scores))
    label = labels[best] if best < len(labels) else f"classe {best}"
    return label, scores[best]


def infer(image: Image.Image):
    """Classify ``image`` with both models and return a two-line text report.

    The image is preprocessed once and sent to the V1 model (reported as
    "Sextante") and then to the V2 model (reported as "Necessidade"). Each
    model is initialized lazily right before it is used.

    Raises:
        ValueError: if no image was provided.
        RuntimeError: if either model cannot be loaded.
    """
    if image is None:
        raise ValueError("Envie uma imagem.")
    image_bytes = _preprocess_image_to_bytes(image)

    # V1 - Sextante
    ok1, err1 = _init_v1()
    if not ok1:
        raise RuntimeError(f"Erro ao carregar modelo V1: {err1}")
    label1, score1 = _top_prediction(_SERVING_V1, image_bytes, "v1", _LABELS_V1)
    sextante = _format_result(label1, score1, "Sextante")

    # V2 - Necessidade
    ok2, err2 = _init_v2()
    if not ok2:
        raise RuntimeError(f"Erro ao carregar modelo V2: {err2}")
    label2, score2 = _top_prediction(_SERVING_V2, image_bytes, "v2", _LABELS_V2)
    necessidade = _format_result(label2, score2, "Necessidade")

    return f"{sextante}\n{necessidade}"
| # ========================= | |
| # Gradio UI | |
| # ========================= | |
# Single-page UI: an image input next to a text output, plus a button that
# runs `infer` on the uploaded image.
# NOTE(review): the original indentation was lost; this assumes the image and
# textbox sit inside the row and the button sits below it - confirm.
with gr.Blocks(title="RaspagemTF - V1 + V2") as demo:
    gr.Markdown("## RaspagemTF — Inferência em dois modelos (Sextante, Necessidade)")
    with gr.Row():
        img = gr.Image(type="pil", label="Imagem")
        res = gr.Textbox(label="Resultados", lines=4)
    btn = gr.Button("Rodar inferência")
    btn.click(fn=infer, inputs=img, outputs=res)

if __name__ == "__main__":
    # Enable the request queue, then start the server.
    demo.queue()
    demo.launch()