File size: 8,013 Bytes
6bc6697
918f35b
5f5b190
 
da1f92c
918f35b
 
6bc6697
00e029a
5f5b190
918f35b
6bc6697
 
33c7b5c
 
 
 
 
 
 
 
 
 
 
5f5b190
33c7b5c
00e029a
5f5b190
 
 
 
 
 
 
 
00e029a
da1f92c
918f35b
6bc6697
 
918f35b
5f5b190
6bc6697
 
918f35b
5f5b190
 
 
918f35b
5f5b190
da1f92c
5f5b190
3dd5b5d
5f5b190
3dd5b5d
 
 
 
 
6bc6697
 
da1f92c
34e0af9
00e029a
5f5b190
 
 
 
 
9d4c81c
00e029a
9d4c81c
5f5b190
00e029a
 
 
 
 
 
 
 
 
 
 
 
6bc6697
5f5b190
918f35b
 
 
 
 
 
 
 
 
 
 
 
 
 
5f5b190
 
918f35b
5f5b190
6bc6697
00e029a
 
5f5b190
6bc6697
5f5b190
 
6bc6697
 
5f5b190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
918f35b
6bc6697
00e029a
 
 
 
 
 
918f35b
5f5b190
 
 
 
 
918f35b
6bc6697
 
 
918f35b
6bc6697
6ed27b2
5f5b190
 
 
 
 
 
 
 
 
6ed27b2
5f5b190
 
 
 
 
 
 
 
 
 
 
 
 
 
9000130
6ed27b2
33c7b5c
2144a7a
 
5f5b190
da1f92c
 
6bc6697
2144a7a
918f35b
6bc6697
918f35b
6bc6697
 
 
 
 
 
918f35b
6bc6697
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# app.py

import os, io, traceback
from typing import Optional, List, Tuple
import torch
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError, ImageFile
from torchvision import transforms as T
from functools import lru_cache

# Allow PIL to decode images whose byte stream is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Redirect every framework cache (XDG, HuggingFace, open_clip, torch hub)
# under one writable root so the app works in read-only / containerized
# home directories. These env vars must be set BEFORE open_clip is imported.
CACHE_ROOT = os.environ.get("APP_CACHE", "/tmp/appcache")
os.environ["XDG_CACHE_HOME"] = CACHE_ROOT
os.environ["HF_HOME"] = os.path.join(CACHE_ROOT, "hf")
os.environ["HUGGINGFACE_HUB_CACHE"] = os.environ["HF_HOME"]
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"]
os.environ["OPENCLIP_CACHE_DIR"] = os.path.join(CACHE_ROOT, "open_clip")
os.environ["TORCH_HOME"] = os.path.join(CACHE_ROOT, "torch")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["OPENCLIP_CACHE_DIR"], exist_ok=True)
os.makedirs(os.environ["TORCH_HOME"], exist_ok=True)

import open_clip  # must be imported after the cache env vars above are set

# ===== basic resource limits =====
# Cap intra-op parallelism; defaults to a single thread (container-friendly).
NUM_THREADS = int(os.environ.get("NUM_THREADS", "1"))
torch.set_num_threads(NUM_THREADS)
os.environ["OMP_NUM_THREADS"] = str(NUM_THREADS)
os.environ["MKL_NUM_THREADS"] = str(NUM_THREADS)
try:
    # Can raise if inter-op parallel work has already started in this process.
    torch.set_num_interop_threads(1)
except Exception:
    pass

# fp16 on GPU, fp32 on CPU; relax fp32 matmul precision only when on CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.float16 if DEVICE == "cuda" else torch.float32
if DEVICE == "cuda":
    torch.set_float32_matmul_precision("high")

# ===== paths to precomputed text embeddings =====
# Each .pt checkpoint holds {"labels": [...], "embeddings": Tensor};
# see _load_embeddings below.
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_bigg.pt")
VERS_EMB_PATH  = os.getenv("VERS_EMB_PATH",  "text_embeddings_bigg.pt")

# ===== PE bigG model =====
MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"
PRETRAINED = None  # weights resolve from the HF hub reference above

app = FastAPI(title="OpenCLIP PE bigG Vehicle API")

# ===== model / preprocess =====
_ret = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
# Some open_clip versions return (model, preprocess_train, preprocess_val).
if isinstance(_ret, tuple) and len(_ret) == 3:
    clip_model, _preprocess_train, preprocess = _ret
else:
    clip_model, preprocess = _ret

# Inference only: move to device/dtype, switch to eval, freeze every parameter.
clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for p in clip_model.parameters():
    p.requires_grad = False

# Pull the Normalize stats and target input size out of the packaged
# preprocess pipeline so the custom letterbox transform stays consistent
# with what the model was trained on.
normalize = next(t for t in getattr(preprocess, "transforms", []) if isinstance(t, T.Normalize))
SIZE = next((getattr(t, "size", None) for t in getattr(preprocess, "transforms", []) if hasattr(t, "size")), None)
if isinstance(SIZE, (tuple, list)):
    SIZE = max(SIZE)
if SIZE is None:
    SIZE = 448  # PE bigG uses 448; fallback

transform = T.Compose([T.ToTensor(), T.Normalize(mean=normalize.mean, std=normalize.std)])

# ===== image utils (letterbox + BICUBIC) =====
def resize_letterbox(img: Image.Image, size: int) -> Image.Image:
    """Fit *img* into a size x size black canvas, preserving aspect ratio.

    Converts to RGB, scales (BICUBIC) so the longest side equals *size*,
    then pastes the result centered on a black square canvas.

    Raises:
        UnidentifiedImageError: if the image has a zero dimension.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    width, height = rgb.size
    if not (width and height):
        raise UnidentifiedImageError("imagen invalida")
    factor = size / max(width, height)
    new_w = max(1, int(width * factor))
    new_h = max(1, int(height * factor))
    scaled = rgb.resize((new_w, new_h), Image.BICUBIC)
    board = Image.new("RGB", (size, size), (0, 0, 0))
    board.paste(scaled, ((size - new_w) // 2, (size - new_h) // 2))
    return board

# ===== cargar embeddings (sin cambios) =====
def _ensure_label_list(x):
    if isinstance(x, (list, tuple)):
        return list(x)
    if hasattr(x, "tolist"):
        return [str(s) for s in x.tolist()]
    return [str(s) for s in x]

def _load_embeddings(path: str):
    ckpt = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(ckpt["labels"])
    embeds = ckpt["embeddings"].to("cpu")
    embeds = embeds / embeds.norm(dim=-1, keepdim=True)
    return labels, embeds

model_labels, model_embeddings     = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)

# Sanity check: the text-embedding width must match the image encoder's
# output width (PE bigG keeps 1280). Discover it via one dummy forward pass.
with torch.inference_mode():
    dummy = torch.zeros(1, 3, SIZE, SIZE, device=DEVICE, dtype=DTYPE)
    img_dim = clip_model.encode_image(dummy).shape[-1]
if model_embeddings.shape[1] != img_dim or version_embeddings.shape[1] != img_dim:
    raise RuntimeError(
        f"dimension mismatch: image={img_dim}, modelos={model_embeddings.shape[1]}, "
        f"versiones={version_embeddings.shape[1]}. Recalcula embeddings con {MODEL_NAME}."
    )

# Memoized per-model (labels, embeddings) version subsets;
# populated lazily by _get_versions_subset.
_versions_cache: dict[str, Tuple[List[str], torch.Tensor]] = {}

def _get_versions_subset(modelo_full: str) -> Tuple[List[str], Optional[torch.Tensor]]:
    """Return (labels, embeddings) of the version labels for one model.

    Results are memoized in _versions_cache; a model with no matching
    version labels caches and returns ([], None).

    NOTE(review): startswith() means a model name that is a prefix of a
    longer one (e.g. "Foo Bar" vs "Foo Bar Cross") also picks up the longer
    model's versions — confirm the label format rules this out.
    """
    cached = _versions_cache.get(modelo_full)
    if cached is not None:
        return cached
    matches = [i for i, label in enumerate(version_labels) if label.startswith(modelo_full)]
    if matches:
        # Fancy-indexing copies the selected rows.
        subset = ([version_labels[i] for i in matches], version_embeddings[matches])
    else:
        subset = ([], None)
    _versions_cache[modelo_full] = subset
    return subset

# ===== inference (logic/precision unchanged) =====
@torch.inference_mode()
def _encode_pil(img: Image.Image) -> torch.Tensor:
    """Encode a PIL image into one L2-normalized CLIP feature row.

    Applies the letterbox resize to SIZE, the model's normalization, then
    the image encoder; casts the batch to fp16 only when running on CUDA.
    """
    letterboxed = resize_letterbox(img, SIZE)
    batch = transform(letterboxed).unsqueeze(0).to(device=DEVICE)
    if DEVICE == "cuda":
        batch = batch.to(dtype=DTYPE)
    embedding = clip_model.encode_image(batch)
    return embedding / embedding.norm(dim=-1, keepdim=True)

def _topk_cosine(text_feats: torch.Tensor, text_labels: List[str], img_feat: torch.Tensor, k: int = 1):
    sim = (img_feat.float() @ text_feats.to(img_feat.device).float().T)[0]
    vals, idxs = torch.topk(sim, k=k)
    conf = torch.softmax(vals, dim=0)
    return [{"label": text_labels[int(i)], "confidence": round(float(c)*100.0, 2)} for i, c in zip(idxs, conf)]

def process_image_bytes(front_bytes: bytes, back_bytes: Optional[bytes] = None):
    """Classify a vehicle photo into brand / model / version.

    Two-stage lookup: the image feature is matched first against the model
    embeddings, then against the cached version embeddings of the winning
    model. A version whose confidence is below 30% is reported as "".

    NOTE(review): back_bytes is accepted but currently unused.

    Raises:
        UnidentifiedImageError: for missing/tiny or undecodable payloads.
    """
    if not front_bytes or len(front_bytes) < 128:
        raise UnidentifiedImageError("imagen invalida")

    feature = _encode_pil(Image.open(io.BytesIO(front_bytes)))

    # Stage 1: best-matching "brand model" label.
    best_model = _topk_cosine(model_embeddings, model_labels, feature, k=1)[0]
    modelo_full = best_model["label"]
    brand, _, model_name = modelo_full.partition(" ")

    # Stage 2: restrict to this model's version labels (memoized).
    labels_sub, embeds_sub = _get_versions_subset(modelo_full)
    if not labels_sub:
        return {"brand": brand.upper(), "model": model_name.title(), "version": ""}

    # Stage 3: best version; strip the "brand model " prefix, keep only the
    # first remaining token, and drop low-confidence answers.
    best_ver = _topk_cosine(embeds_sub, labels_sub, feature, k=1)[0]
    raw_label = best_ver["label"]
    prefix = modelo_full + " "
    version = raw_label[len(prefix):] if raw_label.startswith(prefix) else raw_label
    version = version.split(" ")[0]
    if best_ver["confidence"] < 30.0:
        version = ""

    return {
        "brand": brand.upper(),
        "model": model_name.title(),
        "version": version.title() if version else "",
    }


# ===== endpoints =====
@app.get("/")
def root():
    """Health/info endpoint: device, model name, embedding dim, thread cap."""
    info = {
        "status": "ok",
        "device": DEVICE,
        "model": f"{MODEL_NAME}",
        "img_dim": int(model_embeddings.shape[1]),
        "threads": NUM_THREADS,
    }
    return info

@app.post("/predict/")
async def predict(front: Optional[UploadFile] = File(None), back: Optional[UploadFile] = File(None), request: Optional[Request] = None):
    """Predict vehicle brand/model/version from an uploaded 'front' image.

    Always responds with HTTP 200; the application-level status lives in
    the JSON "code" field: 200 on success, 400 when 'front' is missing,
    404 on decode/inference errors (message in "error").

    Fix: 'front' and 'request' previously declared non-Optional types with
    a None default, which newer FastAPI/Pydantic versions reject; both are
    now Optional (behavior unchanged).

    NOTE(review): 'back' is forwarded but unused downstream; 'request' is
    accepted for future use and ignored.
    """
    try:
        if front is None:
            return JSONResponse(content={"code": 400, "error": "faltan archivos: 'front' es obligatorio"}, status_code=200)
        front_bytes = await front.read()
        back_bytes = await back.read() if back is not None else None
        vehicle = process_image_bytes(front_bytes, back_bytes)
        return JSONResponse(content={"code": 200, "data": {"vehicle": vehicle}}, status_code=200)
    except (UnidentifiedImageError, OSError, RuntimeError, ValueError) as e:
        # Expected failure modes: bad image bytes, runtime/shape errors.
        return JSONResponse(content={"code": 404, "data": {}, "error": str(e)}, status_code=200)
    except Exception:
        # Unexpected: log the traceback, never leak internals to the client.
        traceback.print_exc()
        return JSONResponse(content={"code": 404, "data": {}, "error": "internal"}, status_code=200)