# app.py
# CLIP ViT-H-14 vehicle classifier API: given a photo, predicts brand/model
# (step 1) and, within the matching model prefix, a version (step 2).

import io
import os
import traceback
from typing import Optional

# Local caches (avoid permission errors under /).
# IMPORTANT: these must be exported BEFORE importing torch/open_clip, because
# huggingface_hub and torch.hub read them at import time; setting them after
# the imports (as an earlier revision did) may leave caches at the defaults.
os.environ.setdefault("HF_HOME", "/app/cache")
os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/app/cache/huggingface")
os.environ.setdefault("TORCH_HOME", "/app/cache/torch")
os.makedirs("/app/cache", exist_ok=True)

# Basic thread limits: keep BLAS/OpenMP single-threaded (small container).
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import torch
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError

import open_clip
from torchvision import transforms as T

torch.set_num_threads(1)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Paths to precomputed text embeddings (overridable via environment).
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")

app = FastAPI(title="CLIP H14 Vehicle API")

# ============== CLIP model ==============
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14", pretrained="laion2b_s32b_b79k"
)
clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for p in clip_model.parameters():
    p.requires_grad = False

# Reuse the normalization stats shipped with open_clip's own preprocess
# pipeline so our hand-built transform matches the pretrained weights.
normalize = next(t for t in preprocess.transforms if isinstance(t, T.Normalize))
transform = T.Compose([
    T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=normalize.mean, std=normalize.std),
])


# ============== embeddings ==============
def _ensure_label_list(x) -> list:
    """Coerce labels loaded from a checkpoint into a plain Python list."""
    if isinstance(x, (list, tuple)):
        return list(x)
    if hasattr(x, "tolist"):  # e.g. a tensor/ndarray of labels
        return [str(s) for s in x.tolist()]
    return [str(s) for s in x]


def _load_embeddings(path: str):
    """Load ``(labels, embeddings)`` from *path* and L2-normalize the rows.

    Embeddings were saved as fp16; they stay on CPU here and are cast to the
    runtime device/dtype right before use.

    NOTE(review): ``torch.load`` unpickles arbitrary objects — only point
    MODEL_EMB_PATH / VERS_EMB_PATH at trusted files.
    """
    ckpt = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(ckpt["labels"])
    embeds = ckpt["embeddings"].to("cpu")
    embeds = embeds / embeds.norm(dim=-1, keepdim=True)
    return labels, embeds


model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)


# ============== inference ==============
@torch.inference_mode()
def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
    """Encode a preprocessed image batch and L2-normalize the features."""
    if DEVICE == "cuda":
        # torch.autocast(device_type=...) is the supported replacement for the
        # deprecated torch.cuda.amp.autocast; identical semantics.
        with torch.autocast(device_type="cuda", dtype=DTYPE):
            feats = clip_model.encode_image(img_tensor)
    else:
        feats = clip_model.encode_image(img_tensor)
    return feats / feats.norm(dim=-1, keepdim=True)


def _predict_top(text_feats_dev: torch.Tensor,
                 text_labels: list,
                 image_tensor: torch.Tensor,
                 topk: int = 1):
    """Rank *text_labels* against the image; return the top-k as dicts.

    Each entry is ``{"label": ..., "confidence": ...}`` with confidence as a
    softmax percentage rounded to 2 decimals.
    """
    img_f = _encode_image(image_tensor)
    # Ensure same device and dtype as the image features.
    text_feats_dev = text_feats_dev.to(device=img_f.device, dtype=img_f.dtype)
    sim = (100.0 * img_f @ text_feats_dev.T).softmax(dim=-1)[0]
    vals, idxs = torch.topk(sim, k=topk)
    return [{"label": text_labels[i], "confidence": round(float(v) * 100.0, 2)}
            for v, i in zip(vals, idxs)]


def process_image_bytes(image_bytes: bytes) -> dict:
    """Run the two-step prediction; return only the vehicle dict.

    Returns ``{"brand", "model", "version"}``; ``version`` is "" when no
    version label matches the predicted model prefix or confidence < 25%.

    Raises ``UnidentifiedImageError`` on empty/too-small/undecodable input.
    """
    if not image_bytes or len(image_bytes) < 128:
        raise UnidentifiedImageError("imagen invalida")
    img = Image.open(io.BytesIO(image_bytes))
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)

    # Step 1: top-1 model. Labels look like "BRAND model...".
    model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
    top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
    modelo_full = top_model["label"]
    partes = modelo_full.split(" ", 1)
    marca = partes[0]  # split(...) always yields at least one element
    modelo = partes[1] if len(partes) == 2 else ""

    # Step 2: restrict version candidates to those sharing the model prefix.
    matches = [(lab, idx) for idx, lab in enumerate(version_labels)
               if lab.startswith(modelo_full)]
    if not matches:
        return {
            "brand": marca.upper(),
            "model": modelo.title(),
            "version": ""
        }
    idxs = [i for _, i in matches]
    labels_sub = [lab for lab, _ in matches]
    embeds_sub = version_embeddings[idxs].to(device=DEVICE, dtype=DTYPE)

    # Step 3: top-1 version; keep only the first token after the model prefix.
    top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
    raw = top_ver["label"]
    prefix = modelo_full + " "
    ver = raw[len(prefix):] if raw.startswith(prefix) else raw
    ver = ver.split(" ")[0]
    # Low confidence: leave the version blank rather than guess.
    if top_ver["confidence"] < 25.0:
        ver = ""

    return {
        "brand": marca.upper(),
        "model": modelo.title(),
        "version": ver.title() if ver else ""
    }


# ============== endpoints ==============
@app.get("/")
def root():
    """Liveness probe: reports the inference device in use."""
    return {"status": "ok", "device": DEVICE}


@app.post("/predict/")
async def predict(front: UploadFile = File(None),
                  back: Optional[UploadFile] = File(None),
                  request: Request = None):
    """Predict brand/model/version from the 'front' image upload.

    NOTE(review): by design every response — including errors — is HTTP 200
    with the real status in the body's ``code`` field; callers depend on it.
    """
    try:
        if request:
            print("headers:", dict(request.headers))
        if front is None:
            return JSONResponse(
                content={"code": 400, "error": "faltan archivos: 'front' es obligatorio"},
                status_code=200
            )
        front_bytes = await front.read()
        if back is not None:
            _ = await back.read()  # accepted but currently unused
        vehicle = process_image_bytes(front_bytes)
        return JSONResponse(
            content={"code": 200, "data": {"vehicle": vehicle}},
            status_code=200
        )
    except Exception as e:
        print("EXCEPTION:", repr(e))
        traceback.print_exc()
        return JSONResponse(
            content={"code": 404, "data": {}, "error": str(e)},
            status_code=200
        )