vcollos commited on
Commit
075776b
·
verified ·
1 Parent(s): 7e490aa
Files changed (1) hide show
  1. app.py +225 -0
app.py CHANGED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import shutil
4
+ import tempfile
5
+ from typing import Dict, List, Tuple
6
+
7
+ import numpy as np
8
+ import gradio as gr
9
+ from PIL import Image
10
+ import tensorflow as tf
11
+ from huggingface_hub import hf_hub_download
12
+
13
# =========================
# Config (via environment variables)
# =========================
# Three ways to point at the model:
#   1) MODEL_URL (http/https)                         -> direct download by URL
#   2) MODEL_REPO + MODEL_FILE (+ MODEL_REPO_TYPE: model|space)
#   3) Local path (MODEL_FILE present at the Space root)
MODEL_URL = os.environ.get("MODEL_URL", "").strip()
MODEL_REPO = os.environ.get("MODEL_REPO", "").strip()  # e.g. "vcollos/raspagemTF" or "spaces/vcollos/raspagem_supra"
MODEL_REPO_TYPE = os.environ.get("MODEL_REPO_TYPE", "model").strip()  # "model" or "space"
MODEL_FILE = os.environ.get("MODEL_FILE", "raspagem_model_v1.pb").strip()
LABELS_FILE = os.environ.get("LABELS_FILE", "labels.txt").strip()
IMG_SIZE = int(os.environ.get("IMG_SIZE", "224"))
TOPK = int(os.environ.get("TOPK", "5"))
27
+
28
+ # =========================
29
+ # Baixa/resolve o SavedModel (.pb) e carrega via lazy init
30
+ # =========================
31
+
32
def _download_from_url(url: str) -> str:
    """Download *url* into a fresh temp directory and return the local path.

    Fixes over the original:
      * The local filename is derived from the URL *path* (query string and
        fragment stripped), so ``...model.pb?token=x`` no longer produces a
        filename containing the query string. Falls back to
        ``saved_model.pb`` when the path has no basename.
      * The response is streamed to disk in 1 MiB chunks instead of
        buffering the whole model in memory via ``resp.content``.

    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    import requests
    from urllib.parse import urlparse

    tmp_dir = tempfile.mkdtemp(prefix="raspagem_dl_")
    filename = os.path.basename(urlparse(url).path) or "saved_model.pb"
    local = os.path.join(tmp_dir, filename)
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(local, "wb") as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return local
41
+
42
+
43
def _download_model() -> str:
    """Resolve the model file and return a local filesystem path.

    Resolution order: explicit MODEL_URL, then the HF Hub repo
    (MODEL_REPO / MODEL_FILE), then a file next to the app. A Hub failure
    is logged and falls through to the local-file check.

    Raises:
        FileNotFoundError: when none of the three sources yields a file.
    """
    if MODEL_URL:
        return _download_from_url(MODEL_URL)

    if MODEL_REPO:
        # Guard against bogus env values; anything unknown becomes "model".
        repo_type = MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model"
        try:
            return hf_hub_download(
                repo_id=MODEL_REPO,
                filename=MODEL_FILE,
                repo_type=repo_type,
            )
        except Exception as e:
            print(f"[download] HF hub falhou: {e}")

    if os.path.exists(MODEL_FILE):
        return MODEL_FILE

    raise FileNotFoundError(
        "Modelo não encontrado. Defina MODEL_URL OU (MODEL_REPO, MODEL_REPO_TYPE, MODEL_FILE) OU deixe o arquivo na raiz do Space."
    )
64
+
65
+
66
+ def _prepare_saved_model_dir(pb_path: str) -> str:
67
+ # SavedModel mínimo: diretório contendo 'saved_model.pb'
68
+ tmp_dir = tempfile.mkdtemp(prefix="raspagem_savedmodel_")
69
+ shutil.copy(pb_path, os.path.join(tmp_dir, "saved_model.pb"))
70
+ return tmp_dir
71
+
72
+
73
+ # Lazy state
74
+ _SERVING_FN = None
75
+ _LABELS: List[str] = []
76
+ _LAST_INIT_ERROR: str | None = None
77
+
78
+
79
def _maybe_labels() -> List[str]:
    """Best-effort load of label names from LABELS_FILE.

    When MODEL_REPO is set the file is fetched from the HF Hub, otherwise
    a local path is used. Blank lines are dropped. Labels are optional:
    any failure is logged and an empty list is returned.
    """
    if not LABELS_FILE:
        return []
    try:
        if MODEL_REPO:
            path = hf_hub_download(
                repo_id=MODEL_REPO,
                filename=LABELS_FILE,
                repo_type=MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model",
            )
        else:
            path = LABELS_FILE
        with open(path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]
    except Exception as e:
        print(f"[labels] ignorando erro: {e}")
    return []
96
+
97
+
98
def _init_once() -> Tuple[bool, str]:
    """Load the SavedModel on first call; subsequent calls are no-ops.

    Returns ``(True, "ok")`` on success, or ``(False, message)`` on
    failure — the message is also stored in _LAST_INIT_ERROR so the
    diagnostics panel can show it.
    """
    global _SERVING_FN, _LABELS, _LAST_INIT_ERROR
    if _SERVING_FN is not None:
        return True, "ok"
    try:
        saved_model_dir = _prepare_saved_model_dir(_download_model())
        model = tf.saved_model.load(saved_model_dir)
        # Default signature produced by TF Serving-style exports.
        serving = model.signatures.get("serving_default")
        if serving is None:
            raise RuntimeError("SavedModel sem assinatura 'serving_default'.")
        _SERVING_FN = serving
        _LABELS = _maybe_labels()
        _LAST_INIT_ERROR = None
        return True, "ok"
    except Exception as e:
        _LAST_INIT_ERROR = f"{type(e).__name__}: {e}"
        return False, _LAST_INIT_ERROR
117
+
118
+
119
+ # =========================
120
+ # Pré/Pós-processamento
121
+ # =========================
122
+
123
def _preprocess_image_to_bytes(pil_img: Image.Image) -> bytes:
    """Normalize an input image to IMG_SIZE x IMG_SIZE RGB JPEG bytes.

    The serving signature consumes encoded image bytes, so the PIL image
    is re-encoded as JPEG after the convert/resize.
    """
    normalized = pil_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    buffer = io.BytesIO()
    normalized.save(buffer, format="JPEG")
    return buffer.getvalue()
128
+
129
+
130
def _postprocess(scores: np.ndarray, model_labels: List[str]) -> List[Dict[str, float]]:
    """Convert a 1-D score vector into a list of top-K result dicts.

    Label precedence per class index: the label emitted by the model,
    then the LABELS_FILE entry, then a synthetic ``class_<i>`` name.
    """
    ranked = np.argsort(scores)[-TOPK:][::-1]
    results: List[Dict[str, float]] = []
    for idx in ranked:
        if idx < len(model_labels) and model_labels[idx]:
            name = model_labels[idx]
        elif idx < len(_LABELS):
            name = _LABELS[idx]
        else:
            name = f"class_{idx}"
        results.append({"index": int(idx), "label": name, "score": float(scores[idx])})
    return results
139
+
140
+
141
+ # =========================
142
+ # Funções de UI
143
+ # =========================
144
+
145
def _signature_info() -> Dict[str, Dict[str, str]]:
    """Describe the serving signature's inputs and outputs for the UI panel.

    Returns ``{"init_error": ...}`` when the model failed to initialize.
    """
    ok, err = _init_once()
    if not ok:
        return {"init_error": err}
    sig_inputs = _SERVING_FN.structured_input_signature[1]
    sig_outputs = _SERVING_FN.structured_outputs
    return {
        "inputs": {name: str(spec) for name, spec in sig_inputs.items()},
        "outputs": {name: str(spec) for name, spec in sig_outputs.items()},
    }
152
+
153
+
154
def _diagnostics() -> Dict[str, object]:
    """Report init status plus the effective configuration (UI panel)."""
    ok, err = _init_once()
    # Empty-string env values are shown as null in the JSON output.
    return {
        "ok": ok,
        "error": None if ok else err,
        "env": {
            "MODEL_URL": MODEL_URL or None,
            "MODEL_REPO": MODEL_REPO or None,
            "MODEL_REPO_TYPE": MODEL_REPO_TYPE,
            "MODEL_FILE": MODEL_FILE,
            "IMG_SIZE": IMG_SIZE,
            "TOPK": TOPK,
        },
    }
168
+
169
+
170
def infer(image: Image.Image):
    """Run one prediction and return the top-K list for the JSON widget.

    Assumes a TF Serving-style bytes signature (NOTE: not verifiable from
    here — confirm against the exported model):
      inputs : image_bytes (tf.string), key (tf.string)
      outputs: scores (float32 [1, N]) and optionally labels (string [1, N])

    Raises:
        ValueError: no image was provided.
        RuntimeError: the model failed to initialize.
        KeyError: the signature has no 'scores' output.
    """
    if image is None:
        raise ValueError("Envie uma imagem.")
    ok, err = _init_once()
    if not ok:
        raise RuntimeError(f"Modelo não inicializado: {err}")

    payload = _preprocess_image_to_bytes(image)
    result = _SERVING_FN(
        image_bytes=tf.convert_to_tensor([payload]),
        key=tf.convert_to_tensor(["0"]),
    )

    scores_tensor = result.get("scores")
    if scores_tensor is None:
        raise KeyError("Saída 'scores' não encontrada na assinatura do modelo.")
    scores = scores_tensor.numpy()[0]

    labels_tensor = result.get("labels")
    model_labels: List[str] = []
    if labels_tensor is not None:
        model_labels = [raw.decode("utf-8") for raw in labels_tensor.numpy()[0]]

    return _postprocess(scores, model_labels)
199
+
200
+
201
# =========================
# Gradio UI
# =========================

demo = gr.Blocks(title="RaspagemTF - SavedModel (.pb)")
with demo:
    gr.Markdown("## RaspagemTF — Inferência (SavedModel .pb)")

    # Main inference row: image in, top-K JSON out.
    with gr.Row():
        image_in = gr.Image(type="pil", label="Imagem")
        topk_out = gr.JSON(label="Top-K")
    run_btn = gr.Button("Rodar inferência")
    run_btn.click(fn=infer, inputs=image_in, outputs=topk_out)

    # Collapsible panel: init status + effective configuration.
    with gr.Accordion("Diagnóstico", open=False):
        diag_btn = gr.Button("Rodar diagnóstico")
        diag_out = gr.JSON()
        diag_btn.click(fn=_diagnostics, inputs=None, outputs=diag_out)

    # Collapsible panel: serving signature inputs/outputs.
    with gr.Accordion("Assinaturas do modelo", open=False):
        sig_btn = gr.Button("Mostrar assinatura")
        sig_out = gr.JSON()
        sig_btn.click(fn=_signature_info, inputs=None, outputs=sig_out)

if __name__ == "__main__":
    demo.launch()