afdx2 committed on
Commit
6a5790f
·
verified ·
1 Parent(s): 31ffd52

Update server1.py

Browse files
Files changed (1) hide show
  1. server1.py +35 -20
server1.py CHANGED
@@ -10,7 +10,7 @@ from PIL import Image, UnidentifiedImageError
10
  import open_clip
11
  from torchvision import transforms as T
12
 
13
- # caches locales
14
  os.environ.setdefault("HF_HOME", "/app/cache")
15
  os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
16
  os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
@@ -26,6 +26,7 @@ os.environ["MKL_NUM_THREADS"] = "1"
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
28
 
 
29
  MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
30
  VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")
31
 
@@ -57,7 +58,7 @@ def _ensure_label_list(x):
57
  def _load_embeddings(path: str):
58
  ckpt = torch.load(path, map_location="cpu")
59
  labels = _ensure_label_list(ckpt["labels"])
60
- embeds = ckpt["embeddings"].to("cpu") # guardados como fp16
61
  embeds = embeds / embeds.norm(dim=-1, keepdim=True)
62
  return labels, embeds
63
 
@@ -76,13 +77,14 @@ def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
76
 
77
  def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
78
  img_f = _encode_image(image_tensor)
79
- # casteamos embeddings al mismo dtype que la imagen
80
  text_feats_dev = text_feats_dev.to(device=img_f.device, dtype=img_f.dtype)
81
  sim = (100.0 * img_f @ text_feats_dev.T).softmax(dim=-1)[0]
82
  vals, idxs = torch.topk(sim, k=topk)
83
- return [{"label": text_labels[i], "confidence": round(float(v)*100.0, 2)} for v, i in zip(vals, idxs)]
84
 
85
  def process_image_bytes(image_bytes: bytes):
 
86
  if not image_bytes or len(image_bytes) < 128:
87
  raise UnidentifiedImageError("imagen invalida")
88
 
@@ -95,7 +97,7 @@ def process_image_bytes(image_bytes: bytes):
95
  # paso 1: top-1 modelo
96
  model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
97
  top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
98
- modelo_full = top_model["label"]; conf_m = top_model["confidence"]
99
 
100
  partes = modelo_full.split(" ", 1)
101
  marca = partes[0] if len(partes) >= 1 else ""
@@ -105,8 +107,9 @@ def process_image_bytes(image_bytes: bytes):
105
  matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
106
  if not matches:
107
  return {
108
- "marca": marca.upper(), "modelo": modelo.title(),
109
- "version": "", "confianza_modelo": conf_m, "confianza_version": 0.0
 
110
  }
111
 
112
  idxs = [i for _, i in matches]
@@ -115,20 +118,20 @@ def process_image_bytes(image_bytes: bytes):
115
 
116
  # paso 3: top-1 version
117
  top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
118
- raw = top_ver["label"]; conf_v = top_ver["confidence"]
119
 
120
  prefix = modelo_full + " "
121
  ver = raw[len(prefix):] if raw.startswith(prefix) else raw
122
  ver = ver.split(" ")[0]
123
- if conf_v < 25.0:
124
- ver = "Version no identificada con suficiente confianza"
 
 
125
 
126
  return {
127
- "marca": marca.upper(),
128
- "modelo": modelo.title(),
129
- "version": ver.title() if ver else "",
130
- "confianza_modelo": conf_m,
131
- "confianza_version": conf_v
132
  }
133
 
134
  # ============== endpoints ==============
@@ -137,22 +140,34 @@ def root():
137
  return {"status": "ok", "device": DEVICE}
138
 
139
  @app.post("/predict")
140
- async def predict(front: UploadFile = File(None), back: Optional[UploadFile] = File(None), request: Request = None):
 
 
141
  try:
142
  if request:
143
  print("headers:", dict(request.headers))
144
  if front is None:
145
- return JSONResponse(content={"code": 400, "error": "faltan archivos: 'front' es obligatorio"}, status_code=200)
 
 
 
146
 
147
  front_bytes = await front.read()
148
  if back is not None:
149
  _ = await back.read()
150
 
151
- result = process_image_bytes(front_bytes)
152
- return JSONResponse(content={"code": 200, "data": result})
 
 
 
153
 
154
  except Exception as e:
155
  print("EXCEPTION:", repr(e))
156
  traceback.print_exc()
157
- return JSONResponse(content={"code": 404, "data": {}, "error": str(e)}, status_code=200)
 
 
 
 
158
 
 
10
  import open_clip
11
  from torchvision import transforms as T
12
 
13
+ # caches locales (evitar permisos en /)
14
  os.environ.setdefault("HF_HOME", "/app/cache")
15
  os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
16
  os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
 
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
28
 
29
+ # rutas a embeddings
30
  MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
31
  VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")
32
 
 
58
def _load_embeddings(path: str):
    """Load a text-embedding checkpoint and return (labels, unit-normalized embeddings).

    The checkpoint is expected to contain a "labels" entry (normalized via
    _ensure_label_list) and an "embeddings" tensor.  The embeddings are stored
    as fp16 and kept on CPU here; callers cast to the target device/dtype later.
    """
    checkpoint = torch.load(path, map_location="cpu")
    label_list = _ensure_label_list(checkpoint["labels"])
    vectors = checkpoint["embeddings"].to("cpu")  # stored as fp16; cast later
    # L2-normalize each row so that a dot product equals cosine similarity.
    vectors = vectors / vectors.norm(dim=-1, keepdim=True)
    return label_list, vectors
64
 
 
77
 
78
def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
    """Score one image against a bank of text embeddings and return the top-k matches.

    Returns a list of dicts of the form
    {"label": <text label>, "confidence": <percentage rounded to 2 decimals>}.
    """
    image_features = _encode_image(image_tensor)
    # Ensure the text embeddings share the image features' device and dtype.
    text_feats_dev = text_feats_dev.to(device=image_features.device, dtype=image_features.dtype)
    # CLIP-style scoring: scaled cosine similarities, softmaxed over all labels.
    probs = (100.0 * image_features @ text_feats_dev.T).softmax(dim=-1)[0]
    scores, indices = torch.topk(probs, k=topk)
    results = []
    for score, index in zip(scores, indices):
        results.append({"label": text_labels[index], "confidence": round(float(score) * 100.0, 2)})
    return results
85
 
86
  def process_image_bytes(image_bytes: bytes):
87
+ # devuelve solo el dict vehicle: brand/model/version
88
  if not image_bytes or len(image_bytes) < 128:
89
  raise UnidentifiedImageError("imagen invalida")
90
 
 
97
  # paso 1: top-1 modelo
98
  model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
99
  top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
100
+ modelo_full = top_model["label"]
101
 
102
  partes = modelo_full.split(" ", 1)
103
  marca = partes[0] if len(partes) >= 1 else ""
 
107
  matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
108
  if not matches:
109
  return {
110
+ "brand": marca.upper(),
111
+ "model": modelo.title(),
112
+ "version": ""
113
  }
114
 
115
  idxs = [i for _, i in matches]
 
118
 
119
  # paso 3: top-1 version
120
  top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
121
+ raw = top_ver["label"]
122
 
123
  prefix = modelo_full + " "
124
  ver = raw[len(prefix):] if raw.startswith(prefix) else raw
125
  ver = ver.split(" ")[0]
126
+
127
+ # si baja confianza, no rellenamos version
128
+ if top_ver["confidence"] < 25.0:
129
+ ver = ""
130
 
131
  return {
132
+ "brand": marca.upper(),
133
+ "model": modelo.title(),
134
+ "version": ver.title() if ver else ""
 
 
135
  }
136
 
137
  # ============== endpoints ==============
 
140
  return {"status": "ok", "device": DEVICE}
141
 
142
@app.post("/predict")
async def predict(front: UploadFile = File(None),
                  back: Optional[UploadFile] = File(None),
                  request: Request = None):
    """Classify the 'front' vehicle photo; 'back' is accepted but ignored.

    NOTE(review): every response is sent with HTTP 200 and carries the real
    status in the JSON "code" field (400 missing file, 404 on exception) —
    presumably a client-side convention; confirm before changing.
    """
    try:
        if request:
            print("headers:", dict(request.headers))
        # 'front' is mandatory; report its absence as code 400 in the body.
        if front is None:
            body = {"code": 400, "error": "faltan archivos: 'front' es obligatorio"}
            return JSONResponse(content=body, status_code=200)

        payload = await front.read()
        # Drain 'back' when provided so the upload stream is consumed, then discard it.
        if back is not None:
            _ = await back.read()

        vehicle = process_image_bytes(payload)
        return JSONResponse(content={"code": 200, "data": {"vehicle": vehicle}}, status_code=200)
    except Exception as exc:
        print("EXCEPTION:", repr(exc))
        traceback.print_exc()
        return JSONResponse(content={"code": 404, "data": {}, "error": str(exc)}, status_code=200)
172
+
173