Spaces:
Sleeping
Sleeping
Update server1.py
Browse files- server1.py +35 -20
server1.py
CHANGED
|
@@ -10,7 +10,7 @@ from PIL import Image, UnidentifiedImageError
|
|
| 10 |
import open_clip
|
| 11 |
from torchvision import transforms as T
|
| 12 |
|
| 13 |
-
# caches locales
|
| 14 |
os.environ.setdefault("HF_HOME", "/app/cache")
|
| 15 |
os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
|
| 16 |
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
|
|
@@ -26,6 +26,7 @@ os.environ["MKL_NUM_THREADS"] = "1"
|
|
| 26 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
|
| 28 |
|
|
|
|
| 29 |
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
|
| 30 |
VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")
|
| 31 |
|
|
@@ -57,7 +58,7 @@ def _ensure_label_list(x):
|
|
| 57 |
def _load_embeddings(path: str):
|
| 58 |
ckpt = torch.load(path, map_location="cpu")
|
| 59 |
labels = _ensure_label_list(ckpt["labels"])
|
| 60 |
-
embeds = ckpt["embeddings"].to("cpu")
|
| 61 |
embeds = embeds / embeds.norm(dim=-1, keepdim=True)
|
| 62 |
return labels, embeds
|
| 63 |
|
|
@@ -76,13 +77,14 @@ def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
|
|
| 76 |
|
| 77 |
def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
|
| 78 |
img_f = _encode_image(image_tensor)
|
| 79 |
-
#
|
| 80 |
text_feats_dev = text_feats_dev.to(device=img_f.device, dtype=img_f.dtype)
|
| 81 |
sim = (100.0 * img_f @ text_feats_dev.T).softmax(dim=-1)[0]
|
| 82 |
vals, idxs = torch.topk(sim, k=topk)
|
| 83 |
-
return [{"label": text_labels[i], "confidence": round(float(v)*100.0, 2)} for v, i in zip(vals, idxs)]
|
| 84 |
|
| 85 |
def process_image_bytes(image_bytes: bytes):
|
|
|
|
| 86 |
if not image_bytes or len(image_bytes) < 128:
|
| 87 |
raise UnidentifiedImageError("imagen invalida")
|
| 88 |
|
|
@@ -95,7 +97,7 @@ def process_image_bytes(image_bytes: bytes):
|
|
| 95 |
# paso 1: top-1 modelo
|
| 96 |
model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
|
| 97 |
top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
|
| 98 |
-
modelo_full = top_model["label"]
|
| 99 |
|
| 100 |
partes = modelo_full.split(" ", 1)
|
| 101 |
marca = partes[0] if len(partes) >= 1 else ""
|
|
@@ -105,8 +107,9 @@ def process_image_bytes(image_bytes: bytes):
|
|
| 105 |
matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
|
| 106 |
if not matches:
|
| 107 |
return {
|
| 108 |
-
"
|
| 109 |
-
"
|
|
|
|
| 110 |
}
|
| 111 |
|
| 112 |
idxs = [i for _, i in matches]
|
|
@@ -115,20 +118,20 @@ def process_image_bytes(image_bytes: bytes):
|
|
| 115 |
|
| 116 |
# paso 3: top-1 version
|
| 117 |
top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
|
| 118 |
-
raw = top_ver["label"]
|
| 119 |
|
| 120 |
prefix = modelo_full + " "
|
| 121 |
ver = raw[len(prefix):] if raw.startswith(prefix) else raw
|
| 122 |
ver = ver.split(" ")[0]
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
| 125 |
|
| 126 |
return {
|
| 127 |
-
"
|
| 128 |
-
"
|
| 129 |
-
"version": ver.title() if ver else ""
|
| 130 |
-
"confianza_modelo": conf_m,
|
| 131 |
-
"confianza_version": conf_v
|
| 132 |
}
|
| 133 |
|
| 134 |
# ============== endpoints ==============
|
|
@@ -137,22 +140,34 @@ def root():
|
|
| 137 |
return {"status": "ok", "device": DEVICE}
|
| 138 |
|
| 139 |
@app.post("/predict")
|
| 140 |
-
async def predict(front: UploadFile = File(None),
|
|
|
|
|
|
|
| 141 |
try:
|
| 142 |
if request:
|
| 143 |
print("headers:", dict(request.headers))
|
| 144 |
if front is None:
|
| 145 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
front_bytes = await front.read()
|
| 148 |
if back is not None:
|
| 149 |
_ = await back.read()
|
| 150 |
|
| 151 |
-
|
| 152 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
except Exception as e:
|
| 155 |
print("EXCEPTION:", repr(e))
|
| 156 |
traceback.print_exc()
|
| 157 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
|
|
|
| 10 |
import open_clip
|
| 11 |
from torchvision import transforms as T
|
| 12 |
|
| 13 |
+
# local caches (avoid permission problems when writing under /)
|
| 14 |
os.environ.setdefault("HF_HOME", "/app/cache")
|
| 15 |
os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
|
| 16 |
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
|
|
|
|
| 26 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 27 |
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
|
| 28 |
|
| 29 |
+
# paths to the text-embedding checkpoint files
|
| 30 |
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
|
| 31 |
VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")
|
| 32 |
|
|
|
|
| 58 |
def _load_embeddings(path: str):
    """Load a label/embedding checkpoint and L2-normalize the embeddings.

    Args:
        path: File readable by ``torch.load`` that holds a mapping with the
            keys ``"labels"`` and ``"embeddings"``.

    Returns:
        A ``(labels, embeds)`` tuple: the labels passed through
        ``_ensure_label_list`` and the embeddings as a CPU tensor, in the
        dtype they were stored in, scaled to unit L2 norm along the last
        dimension.

    Raises:
        KeyError: if ``"labels"`` or ``"embeddings"`` is missing (assuming the
            checkpoint is a plain dict).
    """
    # NOTE(security): torch.load unpickles the file, so `path` must only ever
    # point at trusted checkpoints.
    ckpt = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(ckpt["labels"])
    embeds = ckpt["embeddings"].to("cpu")

    # The checkpoints are stored in half precision. Normalizing in that dtype
    # on CPU is imprecise (and unsupported for some ops on older PyTorch
    # builds), so do the arithmetic in float32 and cast back afterwards;
    # callers cast to the inference dtype later anyway.
    stored_dtype = embeds.dtype
    low_precision = stored_dtype in (torch.float16, torch.bfloat16)
    work = embeds.to(torch.float32) if low_precision else embeds

    # Clamp the norm so an all-zero vector yields zeros instead of NaN.
    norms = work.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    embeds = (work / norms).to(stored_dtype)
    return labels, embeds
|
| 64 |
|
|
|
|
| 77 |
|
| 78 |
def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
    """Rank text labels against an image and return the best matches.

    Args:
        text_feats_dev: Text embeddings, one per entry of ``text_labels``
            (presumably already L2-normalized; verify against the loader).
        text_labels: Labels indexed in the same order as ``text_feats_dev``.
        image_tensor: Input forwarded unchanged to ``_encode_image``.
        topk: Number of matches wanted. Values larger than the number of
            available labels are clamped instead of raising.

    Returns:
        A list of ``{"label": str, "confidence": float}`` dicts ordered from
        best to worst, where ``confidence`` is the softmax probability as a
        percentage rounded to two decimals. Only the first row of the
        similarity matrix (i.e. the first encoded image) is scored.
    """
    img_f = _encode_image(image_tensor)

    # Make sure both operands share device and dtype before the matmul.
    text_feats_dev = text_feats_dev.to(device=img_f.device, dtype=img_f.dtype)
    logits = 100.0 * img_f @ text_feats_dev.T

    # Softmax in float32: half-precision probabilities only carry about three
    # significant digits, which is coarser than the two-decimal rounding below.
    sim = logits.float().softmax(dim=-1)[0]

    # torch.topk raises if k exceeds the number of candidates.
    k = min(topk, sim.shape[-1])
    vals, idxs = torch.topk(sim, k=k)

    return [
        {"label": text_labels[i], "confidence": round(v * 100.0, 2)}
        for v, i in zip(vals.tolist(), idxs.tolist())
    ]
|
| 85 |
|
| 86 |
def process_image_bytes(image_bytes: bytes):
|
| 87 |
+
# devuelve solo el dict vehicle: brand/model/version
|
| 88 |
if not image_bytes or len(image_bytes) < 128:
|
| 89 |
raise UnidentifiedImageError("imagen invalida")
|
| 90 |
|
|
|
|
| 97 |
# paso 1: top-1 modelo
|
| 98 |
model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
|
| 99 |
top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
|
| 100 |
+
modelo_full = top_model["label"]
|
| 101 |
|
| 102 |
partes = modelo_full.split(" ", 1)
|
| 103 |
marca = partes[0] if len(partes) >= 1 else ""
|
|
|
|
| 107 |
matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
|
| 108 |
if not matches:
|
| 109 |
return {
|
| 110 |
+
"brand": marca.upper(),
|
| 111 |
+
"model": modelo.title(),
|
| 112 |
+
"version": ""
|
| 113 |
}
|
| 114 |
|
| 115 |
idxs = [i for _, i in matches]
|
|
|
|
| 118 |
|
| 119 |
# paso 3: top-1 version
|
| 120 |
top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
|
| 121 |
+
raw = top_ver["label"]
|
| 122 |
|
| 123 |
prefix = modelo_full + " "
|
| 124 |
ver = raw[len(prefix):] if raw.startswith(prefix) else raw
|
| 125 |
ver = ver.split(" ")[0]
|
| 126 |
+
|
| 127 |
+
# si baja confianza, no rellenamos version
|
| 128 |
+
if top_ver["confidence"] < 25.0:
|
| 129 |
+
ver = ""
|
| 130 |
|
| 131 |
return {
|
| 132 |
+
"brand": marca.upper(),
|
| 133 |
+
"model": modelo.title(),
|
| 134 |
+
"version": ver.title() if ver else ""
|
|
|
|
|
|
|
| 135 |
}
|
| 136 |
|
| 137 |
# ============== endpoints ==============
|
|
|
|
| 140 |
return {"status": "ok", "device": DEVICE}
|
| 141 |
|
| 142 |
@app.post("/predict")
async def predict(front: UploadFile = File(None),
                  back: Optional[UploadFile] = File(None),
                  request: Request = None):
    """Identify the vehicle shown in the uploaded ``front`` image.

    Args:
        front: Required image upload; its bytes are passed to
            ``process_image_bytes``.
        back: Optional second upload. It is read and discarded; it does not
            influence the prediction.
        request: Incoming request, used only to log its headers.

    Returns:
        A ``JSONResponse`` whose HTTP status is always 200. The outcome is
        carried in the body's ``"code"`` field:
        ``200`` with ``{"data": {"vehicle": ...}}`` on success,
        ``400`` when ``front`` is missing, and
        ``404`` with the exception text for any error raised while processing.
    """
    try:
        if request:
            # Never write credentials to the logs: mask the values of headers
            # that carry secrets before printing.
            sensitive = {
                "authorization",
                "proxy-authorization",
                "cookie",
                "set-cookie",
                "x-api-key",
            }
            safe_headers = {
                name: ("<redacted>" if name.lower() in sensitive else value)
                for name, value in dict(request.headers).items()
            }
            print("headers:", safe_headers)

        if front is None:
            return JSONResponse(
                content={"code": 400, "error": "faltan archivos: 'front' es obligatorio"},
                status_code=200
            )

        front_bytes = await front.read()
        if back is not None:
            # Read and discarded: the back image is accepted but unused.
            _ = await back.read()

        vehicle = process_image_bytes(front_bytes)
        return JSONResponse(
            content={"code": 200, "data": {"vehicle": vehicle}},
            status_code=200
        )

    except Exception as e:
        # Top-level boundary: log the failure and report it in the body
        # instead of letting the framework produce a 500.
        print("EXCEPTION:", repr(e))
        traceback.print_exc()
        return JSONResponse(
            content={"code": 404, "data": {}, "error": str(e)},
            status_code=200
        )
|
| 172 |
+
|
| 173 |
|