vcollos commited on
Commit
a6af2d4
·
verified ·
1 Parent(s): c456a5a

BACKUP DE APP.PY

Browse files
Files changed (1) hide show
  1. main.py +233 -67
main.py CHANGED
@@ -1,85 +1,251 @@
 
1
  import os
2
- from functools import lru_cache
3
- from typing import Any, Dict, List
 
4
 
 
5
  import gradio as gr
6
- from google.cloud.aiplatform.gapic import PredictionServiceClient
7
- from google.cloud.aiplatform.gapic.schema.predict.instance import (
8
- ImageClassificationPredictionInstance,
9
- )
10
- from google.cloud.aiplatform.gapic.schema.predict.params import (
11
- ImageClassificationPredictionParams,
12
- )
13
- from google.protobuf.json_format import MessageToDict
14
-
15
-
16
- PROJECT_ID = os.getenv("VERTEX_PROJECT_ID", "366594249966")
17
- ENDPOINT_ID = os.getenv("VERTEX_ENDPOINT_ID", "5122839078575800320")
18
- LOCATION = os.getenv("VERTEX_LOCATION", "us-central1")
19
- CONFIDENCE_THRESHOLD = float(os.getenv("VERTEX_CONFIDENCE_THRESHOLD", "0.2"))
20
- MAX_PREDICTIONS = int(os.getenv("VERTEX_MAX_PREDICTIONS", "5"))
21
-
22
-
23
- @lru_cache(maxsize=1)
24
- def _prediction_client() -> PredictionServiceClient:
25
- api_endpoint = f"{LOCATION}-aiplatform.googleapis.com"
26
- return PredictionServiceClient(client_options={"api_endpoint": api_endpoint})
27
-
28
-
29
- def predict_image_classification_sample(
30
- project: str,
31
- endpoint_id: str,
32
- location: str,
33
- filename: str,
34
- ) -> Dict[str, Any]:
35
- client = _prediction_client()
36
- with open(filename, "rb") as f:
37
- content = f.read()
38
-
39
- instance = ImageClassificationPredictionInstance(content=content).to_value()
40
- params = ImageClassificationPredictionParams(
41
- confidence_threshold=CONFIDENCE_THRESHOLD,
42
- max_predictions=MAX_PREDICTIONS,
43
- ).to_value()
44
-
45
- endpoint_path = client.endpoint_path(project=project, location=location, endpoint=endpoint_id)
46
- response = client.predict(endpoint=endpoint_path, instances=[instance], parameters=params)
47
- predictions: List[Dict[str, Any]] = []
48
- for prediction in response.predictions:
49
- predictions.append(MessageToDict(prediction))
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  return {
52
- "deployed_model_id": response.deployed_model_id,
53
- "predictions": predictions,
 
 
 
 
 
 
 
 
54
  }
55
 
56
 
57
- def _gradio_predict(image_path: str) -> Dict[str, Any]:
58
- if not image_path:
59
- raise ValueError("Envie uma imagem para realizar a inferência.")
 
 
 
 
60
 
61
- result = predict_image_classification_sample(
62
- project=PROJECT_ID,
63
- endpoint_id=ENDPOINT_ID,
64
- location=LOCATION,
65
- filename=image_path,
66
  )
67
 
68
- return result
 
 
 
69
 
 
 
 
 
70
 
71
- with gr.Blocks(title="Vertex AI Image Classification") as demo:
72
- gr.Markdown("## Vertex AI — Classificação de Imagens")
73
- gr.Markdown(
74
- "Envie uma imagem e o aplicativo encaminha a requisição diretamente para o endpoint configurado no Vertex AI."
75
- )
76
 
77
- with gr.Row():
78
- image_input = gr.Image(type="filepath", label="Upload da imagem")
79
- prediction_output = gr.JSON(label="Resposta do Vertex AI")
80
 
81
- predict_button = gr.Button("Realizar inferência")
82
- predict_button.click(_gradio_predict, inputs=image_input, outputs=prediction_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  if __name__ == "__main__":
 
85
  demo.launch()
 
1
+ import io
2
  import os
3
+ import shutil
4
+ import tempfile
5
+ from typing import Dict, List, Tuple
6
 
7
+ import numpy as np
8
  import gradio as gr
9
+ from PIL import Image
10
+ import tensorflow as tf
11
+ from huggingface_hub import hf_hub_download
12
+ import spaces
13
+
14
+ # =========================
15
+ # Config (via Variables)
16
+ # =========================
17
+ # Onde buscar o modelo (.pb):
18
+ # 1) MODEL_URL (http/https)
19
+ # 2) MODEL_REPO + MODEL_FILE (+ MODEL_REPO_TYPE: model|space)
20
+ # 3) Caminho local (MODEL_FILE) na raiz do Space
21
+ MODEL_URL = os.environ.get("MODEL_URL", "").strip()
22
+ MODEL_REPO = os.environ.get("MODEL_REPO", "").strip() # ex: "vcollos/raspagemTF" ou "spaces/vcollos/raspagem_supra"
23
+ MODEL_REPO_TYPE = os.environ.get("MODEL_REPO_TYPE", "model").strip() # "model" ou "space"
24
+ MODEL_FILE = os.environ.get("MODEL_FILE", "raspagem_2025_antes_depois.pb").strip()
25
+ LABELS_FILE = os.environ.get("LABELS_FILE", "labels.txt").strip()
26
+ IMG_SIZE = int(os.environ.get("IMG_SIZE", "224"))
27
+ TOPK = int(os.environ.get("TOPK", "0")) # 0 = lista tudo
28
+
29
+ # =========================
30
+ # Download/resolve SavedModel (.pb) e lazy init
31
+ # =========================
32
+
33
+ def _download_from_url(url: str) -> str:
34
+ import requests
35
+ resp = requests.get(url, timeout=60)
36
+ resp.raise_for_status()
37
+ tmp_dir = tempfile.mkdtemp(prefix="raspagem_dl_")
38
+ local = os.path.join(tmp_dir, os.path.basename(url) or "saved_model.pb")
39
+ with open(local, "wb") as f:
40
+ f.write(resp.content)
41
+ return local
42
+
43
+
44
+ def _download_model() -> str:
45
+ # Prioridade: URL -> HF repo -> arquivo local
46
+ if MODEL_URL:
47
+ return _download_from_url(MODEL_URL)
48
+
49
+ if MODEL_REPO:
50
+ try:
51
+ return hf_hub_download(
52
+ repo_id=MODEL_REPO,
53
+ filename=MODEL_FILE,
54
+ repo_type=MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model",
55
+ )
56
+ except Exception as e:
57
+ print(f"[download] HF hub falhou: {e}")
58
+
59
+ if os.path.exists(MODEL_FILE):
60
+ return MODEL_FILE
61
+
62
+ raise FileNotFoundError(
63
+ "Modelo não encontrado. Defina MODEL_URL OU (MODEL_REPO, MODEL_REPO_TYPE, MODEL_FILE) OU deixe o arquivo na raiz do Space."
64
+ )
65
+
66
 
67
+ def _prepare_saved_model_dir(pb_path: str) -> str:
68
+ # SavedModel mínimo: diretório contendo 'saved_model.pb'
69
+ tmp_dir = tempfile.mkdtemp(prefix="raspagem_savedmodel_")
70
+ shutil.copy(pb_path, os.path.join(tmp_dir, "saved_model.pb"))
71
+ return tmp_dir
72
+
73
+
74
+ # Lazy state
75
+ _SERVING_FN = None
76
+ _LABELS: List[str] = []
77
+ _LAST_INIT_ERROR: str | None = None
78
+
79
+
80
+ def _maybe_labels() -> List[str]:
81
+ try:
82
+ if LABELS_FILE:
83
+ if MODEL_REPO:
84
+ p = hf_hub_download(
85
+ repo_id=MODEL_REPO,
86
+ filename=LABELS_FILE,
87
+ repo_type=MODEL_REPO_TYPE if MODEL_REPO_TYPE in {"model", "space"} else "model",
88
+ )
89
+ else:
90
+ p = LABELS_FILE
91
+ with open(p, "r", encoding="utf-8") as f:
92
+ return [x.strip() for x in f if x.strip()]
93
+ except Exception as e:
94
+ print(f"[labels] ignorando erro: {e}")
95
+ return []
96
+
97
+
98
+ def _init_once() -> Tuple[bool, str]:
99
+ global _SERVING_FN, _LABELS, _LAST_INIT_ERROR
100
+ if _SERVING_FN is not None:
101
+ return True, "ok"
102
+ try:
103
+ pb_local = _download_model()
104
+ sm_dir = _prepare_saved_model_dir(pb_local)
105
+ model = tf.saved_model.load(sm_dir)
106
+ serving = model.signatures.get("serving_default")
107
+ if serving is None:
108
+ raise RuntimeError("SavedModel sem assinatura 'serving_default'.")
109
+ _SERVING_FN = serving
110
+ _LABELS = _maybe_labels()
111
+ _LAST_INIT_ERROR = None
112
+ return True, "ok"
113
+ except Exception as e:
114
+ _LAST_INIT_ERROR = f"{type(e).__name__}: {e}"
115
+ return False, _LAST_INIT_ERROR
116
+
117
+
118
+ # =========================
119
+ # Pré/Pós-processamento
120
+ # =========================
121
+
122
+ def _preprocess_image_to_bytes(pil_img: Image.Image) -> bytes:
123
+ img = pil_img.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
124
+ buf = io.BytesIO()
125
+ img.save(buf, format="JPEG")
126
+ return buf.getvalue()
127
+
128
+
129
+ def _pretty_label(raw: str) -> str:
130
+ s = (raw or "").strip().lower()
131
+ m = {
132
+ "necessario": "Necessário",
133
+ "necessário": "Necessário",
134
+ "nao_necessario": "Não necessário",
135
+ "não_necessário": "Não necessário",
136
+ "s1": "S1",
137
+ "s2": "S2",
138
+ "s3": "S3",
139
+ }
140
+ # remove acentos/espacos no inicio se vier com variações
141
+ key = s.replace(" ", "").replace("ã", "a").replace("á", "a").replace("é", "e").replace("í", "i").replace("ó", "o").replace("ç", "c")
142
+ return m.get(key, raw.strip().capitalize())
143
+
144
+
145
+ def _format_bars(labels: List[str], scores: np.ndarray, topk: int) -> str:
146
+ # Ordena desc, aplica topk (0 = tudo), desenha barras de 20 colunas
147
+ idxs = np.argsort(scores)[::-1]
148
+ if topk and topk > 0:
149
+ idxs = idxs[:topk]
150
+ lines = []
151
+ for i in idxs:
152
+ pct = float(scores[i]) * 100.0
153
+ bar_len = max(1, int(scores[i] * 20))
154
+ bar = "█" * bar_len
155
+ label = _pretty_label(labels[i] if i < len(labels) and labels[i] else ( _LABELS[i] if i < len(_LABELS) else f"class_{i}" ))
156
+ lines.append(f"{label}: {pct:.1f}% {bar}")
157
+ return "\n".join(lines)
158
+
159
+
160
+ # =========================
161
+ # UI functions
162
+ # =========================
163
+
164
+ def _signature_info() -> Dict[str, Dict[str, str]]:
165
+ ok, err = _init_once()
166
+ if not ok:
167
+ return {"init_error": err}
168
+ inputs = {k: str(v) for k, v in _SERVING_FN.structured_input_signature[1].items()}
169
+ outputs = {k: str(v) for k, v in _SERVING_FN.structured_outputs.items()}
170
+ return {"inputs": inputs, "outputs": outputs}
171
+
172
+
173
+ def _diagnostics() -> Dict[str, object]:
174
+ ok, err = _init_once()
175
  return {
176
+ "ok": ok,
177
+ "error": err if not ok else None,
178
+ "env": {
179
+ "MODEL_URL": MODEL_URL or None,
180
+ "MODEL_REPO": MODEL_REPO or None,
181
+ "MODEL_REPO_TYPE": MODEL_REPO_TYPE,
182
+ "MODEL_FILE": MODEL_FILE,
183
+ "IMG_SIZE": IMG_SIZE,
184
+ "TOPK": TOPK,
185
+ },
186
  }
187
 
188
 
189
+ @spaces.GPU(duration=120)
190
+ def infer(image: Image.Image):
191
+ if image is None:
192
+ raise ValueError("Envie uma imagem.")
193
+ ok, err = _init_once()
194
+ if not ok:
195
+ raise RuntimeError(f"Modelo não inicializado: {err}")
196
 
197
+ image_bytes = _preprocess_image_to_bytes(image)
198
+ result = _SERVING_FN(
199
+ image_bytes=tf.convert_to_tensor([image_bytes]),
200
+ key=tf.convert_to_tensor(["0"]),
 
201
  )
202
 
203
+ scores_t = result.get("scores")
204
+ labels_t = result.get("labels")
205
+ if scores_t is None:
206
+ raise KeyError("Saída 'scores' não encontrada na assinatura do modelo.")
207
 
208
+ scores = scores_t.numpy()[0]
209
+ labels: List[str] = []
210
+ if labels_t is not None:
211
+ labels = [x.decode("utf-8") for x in labels_t.numpy()[0]]
212
 
213
+ return _format_bars(labels, scores, TOPK)
 
 
 
 
214
 
 
 
 
215
 
216
+ # =========================
217
+ # Gradio UI
218
+ # =========================
219
+
220
+ demo = gr.Blocks(title="RaspagemTF - SavedModel (.pb)")
221
+ with demo:
222
+ gr.Markdown("## RaspagemTF — Inferência (SavedModel .pb)")
223
+ with gr.Row():
224
+ img = gr.Image(type="pil", label="Imagem")
225
+ res = gr.Textbox(label="Resultados", lines=8)
226
+ btn = gr.Button("Rodar inferência")
227
+ btn.click(fn=infer, inputs=img, outputs=res)
228
+
229
+ with gr.Accordion("Diagnóstico", open=False):
230
+ d_btn = gr.Button("Rodar diagnóstico")
231
+ d_out = gr.JSON()
232
+ d_btn.click(fn=_diagnostics, inputs=None, outputs=d_out)
233
+
234
+ @spaces.GPU(duration=30)
235
+ def _gpu_diag():
236
+ return {
237
+ "tf_version": tf.__version__,
238
+ "gpus_detected": [str(g) for g in tf.config.list_physical_devices('GPU')]
239
+ }
240
+ g_btn = gr.Button("Checar GPU")
241
+ g_out = gr.JSON()
242
+ g_btn.click(fn=_gpu_diag, inputs=None, outputs=g_out)
243
+
244
+ with gr.Accordion("Assinaturas do modelo", open=False):
245
+ s_btn = gr.Button("Mostrar assinatura")
246
+ s_out = gr.JSON()
247
+ s_btn.click(fn=_signature_info, inputs=None, outputs=s_out)
248
 
249
  if __name__ == "__main__":
250
+ demo.queue()
251
  demo.launch()