addgbf committed
Commit 5f5b190 · verified · 1 Parent(s): ba2c1c5

Update server1.py

Files changed (1)
  1. server1.py +92 -104
server1.py CHANGED
@@ -1,17 +1,18 @@
 # app.py
-# no accent marks / no enye
+# comments: no accent marks / no enye

-import os, io, traceback, time
-from typing import Optional, List, Dict
+import os, io, traceback
+from typing import Optional, List, Tuple
 import torch
 from fastapi import FastAPI, File, UploadFile, Request
 from fastapi.responses import JSONResponse
 from PIL import Image, UnidentifiedImageError, ImageFile
 from torchvision import transforms as T
+from functools import lru_cache

 ImageFile.LOAD_TRUNCATED_IMAGES = True

-# ===== caches (writable path) =====
+# ===== caches (use a dedicated path, writable at runtime) =====
 CACHE_ROOT = os.environ.get("APP_CACHE", "/tmp/appcache")
 os.environ["XDG_CACHE_HOME"] = CACHE_ROOT
 os.environ["HF_HOME"] = os.path.join(CACHE_ROOT, "hf")
@@ -23,34 +24,38 @@ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
 os.makedirs(os.environ["OPENCLIP_CACHE_DIR"], exist_ok=True)
 os.makedirs(os.environ["TORCH_HOME"], exist_ok=True)

-import open_clip  # import after setting the caches
+import open_clip  # import after adjusting the caches

 # ===== basic limits =====
-CPU_THREADS = int(os.environ.get("CPU_THREADS", max(1, min(8, os.cpu_count() or 1))))
-torch.set_num_threads(CPU_THREADS)
-os.environ["OMP_NUM_THREADS"] = str(CPU_THREADS)
-os.environ["MKL_NUM_THREADS"] = str(CPU_THREADS)
+# by default we keep 1 thread (your baseline). To try more:
+# export NUM_THREADS=4 (or whatever value you want) without touching the code
+NUM_THREADS = int(os.environ.get("NUM_THREADS", "1"))
+torch.set_num_threads(NUM_THREADS)
+os.environ["OMP_NUM_THREADS"] = str(NUM_THREADS)
+os.environ["MKL_NUM_THREADS"] = str(NUM_THREADS)
+try:
+    torch.set_num_interop_threads(1)
+except Exception:
+    pass

 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 if DEVICE == "cuda":
     torch.set_float32_matmul_precision("high")

-# ===== embedding paths (compatible with bigG/laion2b_s39b_b160k) =====
+# ===== embedding paths =====
 MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_bigg.pt")
 VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_bigg.pt")

-# ===== OpenCLIP bigG model (the same one used to generate the .pt files) =====
-MODEL_NAME = "ViT-bigG-14"
-PRETRAINED = "laion2b_s39b_b160k"
+# ===== PE bigG model =====
+MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"
+PRETRAINED = None

-# input size (defaults to the preprocess size = 448)
-FAST_SIZE = int(os.environ.get("FAST_SIZE", "448"))  # 448 is fine; it does not change the output dimension
+app = FastAPI(title="OpenCLIP PE bigG Vehicle API")

-app = FastAPI(title="OpenCLIP bigG Vehicle API (fast-safe)")
-
-# ===== load model / preprocess =====
+# ===== model / preprocess =====
 _ret = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
+# open_clip versions return (model, preprocess_train, preprocess_val)
 if isinstance(_ret, tuple) and len(_ret) == 3:
     clip_model, _preprocess_train, preprocess = _ret
 else:
@@ -60,16 +65,17 @@ clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
 for p in clip_model.parameters():
     p.requires_grad = False

-# normalization & size from the preprocess (448). Allows lowering it via FAST_SIZE without touching the .pt
+# extract normalization and size from the returned preprocess
 normalize = next(t for t in getattr(preprocess, "transforms", []) if isinstance(t, T.Normalize))
-DEFAULT_SIZE = next((getattr(t, "size", None) for t in getattr(preprocess, "transforms", []) if hasattr(t, "size")), None)
-if isinstance(DEFAULT_SIZE, (tuple, list)):
-    DEFAULT_SIZE = max(DEFAULT_SIZE)
-SIZE = min(DEFAULT_SIZE or 448, FAST_SIZE)
+SIZE = next((getattr(t, "size", None) for t in getattr(preprocess, "transforms", []) if hasattr(t, "size")), None)
+if isinstance(SIZE, (tuple, list)):
+    SIZE = max(SIZE)
+if SIZE is None:
+    SIZE = 448  # PE bigG is 448; fallback

 transform = T.Compose([T.ToTensor(), T.Normalize(mean=normalize.mean, std=normalize.std)])

-# ===== image utils =====
+# ===== image utils (unchanged: letterbox + BICUBIC) =====
 def resize_letterbox(img: Image.Image, size: int) -> Image.Image:
     if img.mode != "RGB":
         img = img.convert("RGB")
@@ -83,7 +89,7 @@ def resize_letterbox(img: Image.Image, size: int) -> Image.Image:
     canvas.paste(img_resized, ((size-nw)//2, (size-nh)//2))
     return canvas

-# ===== load embeddings =====
+# ===== load embeddings (unchanged) =====
 def _ensure_label_list(x):
     if isinstance(x, (list, tuple)):
         return list(x)
@@ -98,51 +104,37 @@ def _load_embeddings(path: str):
     embeds = embeds / embeds.norm(dim=-1, keepdim=True)
     return labels, embeds

-model_labels, model_embeddings_cpu = _load_embeddings(MODEL_EMB_PATH)
-version_labels, version_embeddings_cpu = _load_embeddings(VERS_EMB_PATH)
-
-# move to device once and cache the transposes
-model_embeddings_dev = model_embeddings_cpu.to(device=DEVICE, dtype=DTYPE).contiguous()
-version_embeddings_dev = version_embeddings_cpu.to(device=DEVICE, dtype=DTYPE).contiguous()
-model_embeddings_T = model_embeddings_dev.t().contiguous()
+model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
+version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)

-# check image dimension vs. the text embeddings
+# check the dimension (PE bigG keeps 1280)
 with torch.inference_mode():
     dummy = torch.zeros(1, 3, SIZE, SIZE, device=DEVICE, dtype=DTYPE)
     img_dim = clip_model.encode_image(dummy).shape[-1]
-if model_embeddings_dev.shape[1] != img_dim or version_embeddings_dev.shape[1] != img_dim:
+if model_embeddings.shape[1] != img_dim or version_embeddings.shape[1] != img_dim:
     raise RuntimeError(
-        f"dimension mismatch: image={img_dim}, modelos={model_embeddings_dev.shape[1]}, "
-        f"versiones={version_embeddings_dev.shape[1]}. Recalcula embeddings con {MODEL_NAME}/{PRETRAINED}."
+        f"dimension mismatch: image={img_dim}, modelos={model_embeddings.shape[1]}, "
+        f"versiones={version_embeddings.shape[1]}. Recalcula embeddings con {MODEL_NAME}."
     )

-# ===== version index per model (to avoid startswith on every request) =====
-from collections import defaultdict
-idx_by_model: Dict[str, List[int]] = defaultdict(list)
-
-# sort models by length for correct prefix matching
-models_by_len = sorted(model_labels, key=len, reverse=True)
-for j, lab in enumerate(version_labels):
-    for m in models_by_len:
-        if lab.startswith(m):
-            idx_by_model[m].append(j)
-            break
-
-# pre-cache transposed sub-matrices per model
-ver_index = {}
-for m, idxs in idx_by_model.items():
-    if idxs:
-        embT = version_embeddings_dev[idxs].t().contiguous()
-        labs = [version_labels[i] for i in idxs]
-        ver_index[m] = (embT, labs)
-    else:
-        ver_index[m] = (None, [])
-
-# ===== warm-up (reduces the first cold request) =====
-with torch.inference_mode():
-    _ = clip_model.encode_image(torch.zeros(1, 3, SIZE, SIZE, device=DEVICE, dtype=DTYPE))
-
-# ===== inference =====
+# ===== lazy cache of sub-embeddings per modelo_full =====
+# does not change accuracy; it just avoids scanning version_labels on every request
+_versions_cache: dict[str, Tuple[List[str], torch.Tensor]] = {}
+
+def _get_versions_subset(modelo_full: str) -> Tuple[List[str], Optional[torch.Tensor]]:
+    hit = _versions_cache.get(modelo_full)
+    if hit is not None:
+        return hit
+    idxs = [i for i, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
+    if not idxs:
+        _versions_cache[modelo_full] = ([], None)
+        return _versions_cache[modelo_full]
+    labels_sub = [version_labels[i] for i in idxs]
+    embeds_sub = version_embeddings[idxs]  # a copy of those rows
+    _versions_cache[modelo_full] = (labels_sub, embeds_sub)
+    return _versions_cache[modelo_full]
+
+# ===== inference (unchanged logic/accuracy) =====
 @torch.inference_mode()
 def _encode_pil(img: Image.Image) -> torch.Tensor:
     img = resize_letterbox(img, SIZE)
@@ -152,62 +144,58 @@ def _encode_pil(img: Image.Image) -> torch.Tensor:
     feats = clip_model.encode_image(tensor)
     return feats / feats.norm(dim=-1, keepdim=True)

-def _top1(text_feats_T: torch.Tensor, img_feat: torch.Tensor):
-    sim = (img_feat @ text_feats_T)[0].float()
-    val, idx = torch.topk(sim, k=1)
-    conf = torch.softmax(val, dim=0)[0]
-    return int(idx), float(conf)*100.0
+def _topk_cosine(text_feats: torch.Tensor, text_labels: List[str], img_feat: torch.Tensor, k: int = 1):
+    sim = (img_feat.float() @ text_feats.to(img_feat.device).float().T)[0]
+    vals, idxs = torch.topk(sim, k=k)
+    conf = torch.softmax(vals, dim=0)
+    return [{"label": text_labels[int(i)], "confidence": round(float(c)*100.0, 2)} for i, c in zip(idxs, conf)]

 def process_image_bytes(front_bytes: bytes, back_bytes: Optional[bytes] = None):
     if not front_bytes or len(front_bytes) < 128:
         raise UnidentifiedImageError("imagen invalida")

     img_front = Image.open(io.BytesIO(front_bytes))
-    feat = _encode_pil(img_front)
+    feat_front = _encode_pil(img_front)

     if back_bytes:
         try:
             img_back = Image.open(io.BytesIO(back_bytes))
-            feat_b = _encode_pil(img_back)
-            feat = (feat + feat_b)
-            feat = feat / feat.norm(dim=-1, keepdim=True)
+            feat_back = _encode_pil(img_back)
+            img_feat = (feat_front + feat_back) / 2
+            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
         except Exception:
-            pass
-
-    # 1) model (top-1)
-    idx_m, _ = _top1(model_embeddings_T, feat)
-    modelo_full = model_labels[idx_m]
-
-    # 2) version (only within the indexed subset)
-    embT, labs = ver_index.get(modelo_full, (None, []))
-    version_out = ""
-    if embT is not None and len(labs) > 0:
-        idx_v, conf_v = _top1(embT, feat)
-        raw = labs[idx_v]
-        prefix = modelo_full + " "
-        ver = raw[len(prefix):] if raw.startswith(prefix) else raw
-        ver = ver.split(" ")[0]
-        if conf_v >= 30.0:
-            version_out = ver.title()
-
-    parts = modelo_full.split(" ", 1)
-    marca = parts[0] if len(parts) >= 1 else ""
-    modelo = parts[1] if len(parts) == 2 else ""
-
-    return {"brand": marca.upper(), "model": modelo.title(), "version": version_out}
+            img_feat = feat_front
+    else:
+        img_feat = feat_front
+
+    # step 1: model
+    top_model = _topk_cosine(model_embeddings, model_labels, img_feat, k=1)[0]
+    modelo_full = top_model["label"]
+
+    partes = modelo_full.split(" ", 1)
+    marca = partes[0] if len(partes) >= 1 else ""
+    modelo = partes[1] if len(partes) == 2 else ""
+
+    # step 2: versions with cache (same logic, without the global loop each time)
+    labels_sub, embeds_sub = _get_versions_subset(modelo_full)
+    if not labels_sub:
+        return {"brand": marca.upper(), "model": modelo.title(), "version": ""}
+
+    # step 3: version
+    top_ver = _topk_cosine(embeds_sub, labels_sub, img_feat, k=1)[0]
+    raw = top_ver["label"]
+    prefix = modelo_full + " "
+    ver = raw[len(prefix):] if raw.startswith(prefix) else raw
+    ver = ver.split(" ")[0]
+    if top_ver["confidence"] < 30.0:
+        ver = ""
+
+    return {"brand": marca.upper(), "model": modelo.title(), "version": ver.title() if ver else ""}

 # ===== endpoints =====
 @app.get("/")
 def root():
-    return {
-        "status": "ok",
-        "device": DEVICE,
-        "dtype": str(DTYPE),
-        "model": f"{MODEL_NAME}/{PRETRAINED}",
-        "size": SIZE,
-        "img_dim": int(model_embeddings_dev.shape[1]),
-        "threads": CPU_THREADS
-    }
+    return {"status": "ok", "device": DEVICE, "model": f"{MODEL_NAME}", "img_dim": int(model_embeddings.shape[1]), "threads": NUM_THREADS}

 @app.post("/predict/")
 async def predict(front: UploadFile = File(None), back: Optional[UploadFile] = File(None), request: Request = None):
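Note: the dimension check in this commit aborts startup on a mismatch and asks for the embeddings to be recomputed with the new model. A minimal regeneration sketch could look like the following hypothetical helper, which is not part of this commit; the "labels"/"embeddings" keys and the label list are placeholders, so match whatever layout _load_embeddings actually reads.

# regen_embeddings.py -- hypothetical helper, not part of this commit.
# Assumes each .pt stores {"labels": [...], "embeddings": <tensor>}; adjust
# the keys to the layout _load_embeddings really expects.
import torch
import open_clip

MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"

model, *_rest = open_clip.create_model_and_transforms(MODEL_NAME)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
model.eval()

labels = ["toyota corolla", "toyota corolla xli"]  # placeholder label list

with torch.inference_mode():
    feats = model.encode_text(tokenizer(labels))
    feats = feats / feats.norm(dim=-1, keepdim=True)  # same L2 normalization the server applies

torch.save({"labels": labels, "embeddings": feats.cpu()}, "text_embeddings_bigg.pt")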
 
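For a quick smoke test of the updated endpoint, a client call could look like the sketch below, assuming the app is served with uvicorn on localhost:8000 (e.g. NUM_THREADS=4 uvicorn server1:app --port 8000; module name and port are assumptions). The multipart field names front and back come from the predict signature, and back is optional.

# hypothetical client for a local smoke test; host and port are assumptions
import requests

with open("front.jpg", "rb") as f_front, open("back.jpg", "rb") as f_back:
    resp = requests.post(
        "http://localhost:8000/predict/",
        files={"front": f_front, "back": f_back},  # "back" may be omitted
    )
print(resp.json())  # expected shape: {"brand": "...", "model": "...", "version": "..."}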