File size: 8,013 Bytes
6bc6697 918f35b 5f5b190 da1f92c 918f35b 6bc6697 00e029a 5f5b190 918f35b 6bc6697 33c7b5c 5f5b190 33c7b5c 00e029a 5f5b190 00e029a da1f92c 918f35b 6bc6697 918f35b 5f5b190 6bc6697 918f35b 5f5b190 918f35b 5f5b190 da1f92c 5f5b190 3dd5b5d 5f5b190 3dd5b5d 6bc6697 da1f92c 34e0af9 00e029a 5f5b190 9d4c81c 00e029a 9d4c81c 5f5b190 00e029a 6bc6697 5f5b190 918f35b 5f5b190 918f35b 5f5b190 6bc6697 00e029a 5f5b190 6bc6697 5f5b190 6bc6697 5f5b190 918f35b 6bc6697 00e029a 918f35b 5f5b190 918f35b 6bc6697 918f35b 6bc6697 6ed27b2 5f5b190 6ed27b2 5f5b190 9000130 6ed27b2 33c7b5c 2144a7a 5f5b190 da1f92c 6bc6697 2144a7a 918f35b 6bc6697 918f35b 6bc6697 918f35b 6bc6697 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 |
# app.py
import os, io, traceback
from typing import Optional, List, Tuple
import torch
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError, ImageFile
from torchvision import transforms as T
from functools import lru_cache
# Allow Pillow to decode images whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Redirect every model/weight cache under one writable root (container-friendly).
CACHE_ROOT = os.environ.get("APP_CACHE", "/tmp/appcache")
os.environ["XDG_CACHE_HOME"] = CACHE_ROOT
os.environ["HF_HOME"] = os.path.join(CACHE_ROOT, "hf")
os.environ["HUGGINGFACE_HUB_CACHE"] = os.environ["HF_HOME"]
os.environ["TRANSFORMERS_CACHE"] = os.environ["HF_HOME"]
os.environ["OPENCLIP_CACHE_DIR"] = os.path.join(CACHE_ROOT, "open_clip")
os.environ["TORCH_HOME"] = os.path.join(CACHE_ROOT, "torch")
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
os.makedirs(os.environ["OPENCLIP_CACHE_DIR"], exist_ok=True)
os.makedirs(os.environ["TORCH_HOME"], exist_ok=True)
import open_clip  # imported after the cache env vars are set so they take effect
# ===== basic resource limits =====
NUM_THREADS = int(os.environ.get("NUM_THREADS", "1"))
torch.set_num_threads(NUM_THREADS)
os.environ["OMP_NUM_THREADS"] = str(NUM_THREADS)
os.environ["MKL_NUM_THREADS"] = str(NUM_THREADS)
try:
    # best effort: this call can raise if inter-op parallel work already started
    torch.set_num_interop_threads(1)
except Exception:
    pass
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 only on CUDA; CPU inference stays in fp32
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
if DEVICE == "cuda":
    torch.set_float32_matmul_precision("high")
# ===== embedding file paths =====
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_bigg.pt")
VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_bigg.pt")
# ===== PE bigG model =====
MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"
PRETRAINED = None  # weights are resolved from the hf-hub reference in MODEL_NAME
app = FastAPI(title="OpenCLIP PE bigG Vehicle API")
# ===== model / preprocess =====
_ret = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
# some open_clip versions return (model, preprocess_train, preprocess_val)
if isinstance(_ret, tuple) and len(_ret) == 3:
    clip_model, _preprocess_train, preprocess = _ret
else:
    clip_model, preprocess = _ret
clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for p in clip_model.parameters():
    p.requires_grad = False  # inference only
# pull the Normalize step (mean/std) out of the stock preprocess pipeline
normalize = next(t for t in getattr(preprocess, "transforms", []) if isinstance(t, T.Normalize))
# infer the model's expected input resolution from the preprocess pipeline
SIZE = next((getattr(t, "size", None) for t in getattr(preprocess, "transforms", []) if hasattr(t, "size")), None)
if isinstance(SIZE, (tuple, list)):
    SIZE = max(SIZE)
if SIZE is None:
    SIZE = 448  # PE bigG is 448; fallback
# custom pipeline: resizing is handled by resize_letterbox, so only ToTensor + Normalize here
transform = T.Compose([T.ToTensor(), T.Normalize(mean=normalize.mean, std=normalize.std)])
# ===== image utils (unchanged: letterbox + BICUBIC) =====
def resize_letterbox(img: Image.Image, size: int) -> Image.Image:
    """Fit *img* into a ``size`` x ``size`` black square, preserving aspect ratio.

    The image is converted to RGB if needed, scaled with BICUBIC so its longer
    side equals ``size``, and pasted centered onto a black canvas (letterbox).

    Raises:
        UnidentifiedImageError: if either image dimension is zero.
    """
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    width, height = rgb.size
    if not width or not height:
        raise UnidentifiedImageError("imagen invalida")
    ratio = size / max(width, height)
    new_w = max(1, int(width * ratio))
    new_h = max(1, int(height * ratio))
    scaled = rgb.resize((new_w, new_h), Image.BICUBIC)
    canvas = Image.new("RGB", (size, size), (0, 0, 0))
    canvas.paste(scaled, ((size - new_w) // 2, (size - new_h) // 2))
    return canvas
# ===== load embeddings (unchanged) =====
def _ensure_label_list(x):
if isinstance(x, (list, tuple)):
return list(x)
if hasattr(x, "tolist"):
return [str(s) for s in x.tolist()]
return [str(s) for s in x]
def _load_embeddings(path: str):
ckpt = torch.load(path, map_location="cpu")
labels = _ensure_label_list(ckpt["labels"])
embeds = ckpt["embeddings"].to("cpu")
embeds = embeds / embeds.norm(dim=-1, keepdim=True)
return labels, embeds
# ===== load text embeddings for both classification stages at import time =====
model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)
# sanity-check embedding dimension (PE bigG keeps 1280)
with torch.inference_mode():
    dummy = torch.zeros(1, 3, SIZE, SIZE, device=DEVICE, dtype=DTYPE)
    img_dim = clip_model.encode_image(dummy).shape[-1]
if model_embeddings.shape[1] != img_dim or version_embeddings.shape[1] != img_dim:
    raise RuntimeError(
        f"dimension mismatch: image={img_dim}, modelos={model_embeddings.shape[1]}, "
        f"versiones={version_embeddings.shape[1]}. Recalcula embeddings con {MODEL_NAME}."
    )
# memo cache: full model label -> (matching version labels, their embedding rows)
_versions_cache: dict[str, Tuple[List[str], torch.Tensor]] = {}
def _get_versions_subset(modelo_full: str) -> Tuple[List[str], Optional[torch.Tensor]]:
    """Return the (labels, embeddings) subset of versions for one model.

    Selects every entry of ``version_labels`` that starts with *modelo_full*
    and memoizes the result in ``_versions_cache``. When nothing matches,
    ``([], None)`` is cached and returned.

    NOTE(review): a plain startswith() prefix test could also match a longer
    model name sharing the prefix (e.g. "X" vs "X Plus") -- confirm the label
    format rules out such collisions.
    """
    cached = _versions_cache.get(modelo_full)
    if cached is not None:
        return cached
    matching = [i for i, label in enumerate(version_labels) if label.startswith(modelo_full)]
    if matching:
        subset = ([version_labels[i] for i in matching], version_embeddings[matching])
    else:
        subset = ([], None)
    _versions_cache[modelo_full] = subset
    return subset
# ===== inference (logic/precision unchanged) =====
@torch.inference_mode()
def _encode_pil(img: Image.Image) -> torch.Tensor:
    """Encode a PIL image into a unit-norm CLIP image embedding of shape (1, D)."""
    letterboxed = resize_letterbox(img, SIZE)
    batch = transform(letterboxed).unsqueeze(0).to(device=DEVICE)
    if DEVICE == "cuda":
        # match the fp16 model weights on GPU
        batch = batch.to(dtype=DTYPE)
    embedding = clip_model.encode_image(batch)
    return embedding / embedding.norm(dim=-1, keepdim=True)
def _topk_cosine(text_feats: torch.Tensor, text_labels: List[str], img_feat: torch.Tensor, k: int = 1):
sim = (img_feat.float() @ text_feats.to(img_feat.device).float().T)[0]
vals, idxs = torch.topk(sim, k=k)
conf = torch.softmax(vals, dim=0)
return [{"label": text_labels[int(i)], "confidence": round(float(c)*100.0, 2)} for i, c in zip(idxs, conf)]
def process_image_bytes(front_bytes: bytes, back_bytes: Optional[bytes] = None):
    """Classify a vehicle photo into brand / model / version.

    Two-stage pipeline: (1) nearest model label by cosine similarity, then
    (2) nearest version label restricted to that model's versions.

    Args:
        front_bytes: encoded image bytes of the front photo (required).
        back_bytes: currently unused; accepted for interface stability.

    Returns:
        dict with keys "brand" (upper-cased), "model" (title-cased) and
        "version" (title-cased, possibly "").

    Raises:
        UnidentifiedImageError: on missing/too-small payloads or undecodable images.
    """
    # reject trivially small payloads before handing them to Pillow
    if not front_bytes or len(front_bytes) < 128:
        raise UnidentifiedImageError("imagen invalida")
    img_front = Image.open(io.BytesIO(front_bytes))
    img_feat = _encode_pil(img_front)
    # step 1: model
    top_model = _topk_cosine(model_embeddings, model_labels, img_feat, k=1)[0]
    modelo_full = top_model["label"]
    # label format assumed to be "<brand> <model...>" -- split once on the first space
    partes = modelo_full.split(" ", 1)
    marca = partes[0] if len(partes) >= 1 else ""
    modelo = partes[1] if len(partes) == 2 else ""
    # step 2: versions for this model (cached)
    labels_sub, embeds_sub = _get_versions_subset(modelo_full)
    if not labels_sub:
        # no versions known for this model -- return without a version
        return {"brand": marca.upper(), "model": modelo.title(), "version": ""}
    # step 3: version
    top_ver = _topk_cosine(embeds_sub, labels_sub, img_feat, k=1)[0]
    raw = top_ver["label"]
    # strip the "<model> " prefix from the version label, keep the first token
    prefix = modelo_full + " "
    ver = raw[len(prefix):] if raw.startswith(prefix) else raw
    ver = ver.split(" ")[0]
    # NOTE(review): verify this gate can actually fire -- _topk_cosine's
    # confidence semantics for k=1 determine whether it is ever < 30.
    if top_ver["confidence"] < 30.0:
        ver = ""
    return {"brand": marca.upper(), "model": modelo.title(), "version": ver.title() if ver else ""}
# ===== endpoints =====
@app.get("/")
def root():
    """Health/info endpoint: service status plus runtime configuration."""
    info = {
        "status": "ok",
        "device": DEVICE,
        "model": f"{MODEL_NAME}",
        "img_dim": int(model_embeddings.shape[1]),
        "threads": NUM_THREADS,
    }
    return info
@app.post("/predict/")
async def predict(front: UploadFile = File(None), back: Optional[UploadFile] = File(None), request: Request = None):
    """Classify an uploaded vehicle photo.

    Accepts multipart uploads: 'front' (required) and 'back' (optional,
    currently unused downstream). Always responds HTTP 200; the application
    status lives in the JSON "code" field (200 ok, 400 missing file,
    404 processing error) -- apparently a deliberate envelope convention.
    """
    try:
        if front is None:
            return JSONResponse(content={"code": 400, "error": "faltan archivos: 'front' es obligatorio"}, status_code=200)
        front_bytes = await front.read()
        back_bytes = await back.read() if back is not None else None
        vehicle = process_image_bytes(front_bytes, back_bytes)
        return JSONResponse(content={"code": 200, "data": {"vehicle": vehicle}}, status_code=200)
    except (UnidentifiedImageError, OSError, RuntimeError, ValueError) as e:
        # expected failure modes (bad image, I/O, dimension mismatch): report message
        return JSONResponse(content={"code": 404, "data": {}, "error": str(e)}, status_code=200)
    except Exception:
        # anything unexpected: log the traceback, hide details from the client
        traceback.print_exc()
        return JSONResponse(content={"code": 404, "data": {}, "error": "internal"}, status_code=200)
|