# NOTE: "Spaces: Sleeping" status banner removed — Hugging Face Spaces UI
# residue from the page scrape, not part of the source file.
# app.py
#
# CLIP ViT-H-14 vehicle classification API (FastAPI).
#
# FIX: the cache environment variables must be exported BEFORE importing the
# torch / open_clip / huggingface stack — huggingface_hub resolves HF_HOME and
# the cache paths at import time, and OpenMP/MKL thread limits must be in the
# environment before the native libraries initialize.  The original set them
# after the imports, so they could silently have no effect.
import os

# local caches (avoid permission problems writing under /)
os.environ.setdefault("HF_HOME", "/app/cache")
os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/app/cache/huggingface")
os.environ.setdefault("TORCH_HOME", "/app/cache/torch")
os.makedirs("/app/cache", exist_ok=True)

# basic limits: pin native math libraries to a single thread
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import io
import traceback
from typing import Optional

import torch
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
import open_clip
from torchvision import transforms as T

torch.set_num_threads(1)
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32 | |
| # rutas a embeddings | |
| MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt") | |
| VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt") | |
| app = FastAPI(title="CLIP H14 Vehicle API") | |
# ============== CLIP model ==============
# ViT-H-14 pretrained on LAION-2B; inference only, gradients disabled.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14", pretrained="laion2b_s32b_b79k"
)
clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for param in clip_model.parameters():
    param.requires_grad = False

# Rebuild the preprocessing pipeline with a fixed 224x224 bicubic resize while
# reusing the normalization constants shipped with the pretrained checkpoint.
normalize = next(t for t in preprocess.transforms if isinstance(t, T.Normalize))
transform = T.Compose([
    T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=normalize.mean, std=normalize.std),
])
| # ============== embeddings ============== | |
| def _ensure_label_list(x): | |
| if isinstance(x, (list, tuple)): | |
| return list(x) | |
| if hasattr(x, "tolist"): | |
| return [str(s) for s in x.tolist()] | |
| return [str(s) for s in x] | |
def _load_embeddings(path: str):
    """Load a text-embedding checkpoint from *path*.

    The checkpoint is a dict with keys ``labels`` and ``embeddings``
    (stored as fp16; cast to the target dtype later, at request time).
    Embeddings are L2-normalized row-wise so that similarity reduces to
    a dot product.  Returns ``(labels, embeddings)`` kept on CPU.
    """
    checkpoint = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(checkpoint["labels"])
    embeddings = checkpoint["embeddings"].to("cpu")
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
    return labels, embeddings
# Load both label/embedding tables once at startup; they stay on CPU and are
# moved to the inference device per request.
model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)
# ============== inference ==============
def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
    """Encode a preprocessed image batch into L2-normalized CLIP features.

    FIXES vs. original:
      * wrapped in ``torch.no_grad()`` — this is pure inference, so no
        autograd bookkeeping should be kept;
      * ``torch.autocast`` replaces the deprecated
        ``torch.cuda.amp.autocast``.

    Args:
        img_tensor: batched image tensor, expected already on ``DEVICE``
            with dtype ``DTYPE`` (see caller).

    Returns:
        Feature tensor with unit L2 norm along the last dimension.
    """
    with torch.no_grad():
        if DEVICE == "cuda":
            with torch.autocast(device_type="cuda", dtype=DTYPE):
                feats = clip_model.encode_image(img_tensor)
        else:
            feats = clip_model.encode_image(img_tensor)
    return feats / feats.norm(dim=-1, keepdim=True)
def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
    """Rank *text_labels* against one image; return the top-k matches.

    Similarities are scaled by 100 (CLIP logit-scale convention) and
    softmaxed over the whole label set, so each confidence is a percentage
    relative to the candidates supplied.

    Returns a list of ``{"label": str, "confidence": float}`` dicts.
    """
    image_features = _encode_image(image_tensor)
    # ensure the text features share the image features' device and dtype
    text_feats_dev = text_feats_dev.to(device=image_features.device, dtype=image_features.dtype)
    probs = (100.0 * image_features @ text_feats_dev.T).softmax(dim=-1)[0]
    top_vals, top_idxs = torch.topk(probs, k=topk)
    results = []
    for value, index in zip(top_vals, top_idxs):
        results.append({"label": text_labels[index], "confidence": round(float(value) * 100.0, 2)})
    return results
def process_image_bytes(image_bytes: bytes):
    """Classify a vehicle photo; return a ``brand``/``model``/``version`` dict.

    Pipeline:
      1. top-1 model match against the model embedding table;
      2. filter the version table down to versions of that exact model;
      3. top-1 version match, blanked out when confidence is low.

    Raises:
        UnidentifiedImageError: payload empty/too small, or not decodable by PIL.
    """
    # reject obviously-invalid payloads before handing them to PIL
    if not image_bytes or len(image_bytes) < 128:
        raise UnidentifiedImageError("imagen invalida")
    img = Image.open(io.BytesIO(image_bytes))
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)

    # step 1: top-1 model; labels look like "<brand> <model ...>"
    model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
    top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
    modelo_full = top_model["label"]
    # str.split/partition always yields a first part, so brand is always set
    marca, _, modelo = modelo_full.partition(" ")

    # step 2: restrict version labels to this exact model.
    # FIX: match "<model> " (trailing space) instead of the bare model name,
    # so e.g. "Corolla" no longer picks up "Corolla Cross" versions; this is
    # consistent with the prefix stripping done in step 3.
    prefix = modelo_full + " "
    matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(prefix)]
    if not matches:
        return {"brand": marca.upper(), "model": modelo.title(), "version": ""}

    labels_sub = [lab for lab, _ in matches]
    embeds_sub = version_embeddings[[i for _, i in matches]].to(device=DEVICE, dtype=DTYPE)

    # step 3: top-1 version; keep only the first token after the model name
    top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
    raw = top_ver["label"]
    ver = raw[len(prefix):] if raw.startswith(prefix) else raw
    ver = ver.split(" ")[0]
    # low confidence -> leave the version empty rather than guess
    if top_ver["confidence"] < 25.0:
        ver = ""
    return {
        "brand": marca.upper(),
        "model": modelo.title(),
        "version": ver.title() if ver else "",
    }
# ============== endpoints ==============
# FIX: the handler had no route decorator, so it was never registered on the
# FastAPI app and the service exposed no health-check endpoint.
@app.get("/")
def root():
    """Health check: reports liveness and the active inference device."""
    return {"status": "ok", "device": DEVICE}
# FIX: register the handler on the app — without a route decorator it was
# unreachable over HTTP.  Route path assumed to be "/predict"; confirm
# against deployed clients.  'front' is annotated Optional to match 'back'
# and to satisfy newer FastAPI/Pydantic validation of None defaults.
@app.post("/predict")
async def predict(front: Optional[UploadFile] = File(None),
                  back: Optional[UploadFile] = File(None),
                  request: Request = None):
    """Classify the vehicle in the 'front' image upload.

    'back' is accepted for API compatibility but its content is unused.
    The endpoint always answers HTTP 200; the application-level status is
    carried in the 'code' field of the JSON body (existing client contract).
    """
    try:
        if request:
            print("headers:", dict(request.headers))
        if front is None:
            # missing required upload -> app-level 400 (HTTP stays 200)
            return JSONResponse(
                content={"code": 400, "error": "faltan archivos: 'front' es obligatorio"},
                status_code=200
            )
        front_bytes = await front.read()
        if back is not None:
            # drain the optional second upload; its content is ignored
            _ = await back.read()
        vehicle = process_image_bytes(front_bytes)
        return JSONResponse(
            content={"code": 200, "data": {"vehicle": vehicle}},
            status_code=200
        )
    except Exception as e:
        # best-effort error boundary: log the traceback, reply app-level 404
        print("EXCEPTION:", repr(e))
        traceback.print_exc()
        return JSONResponse(
            content={"code": 404, "data": {}, "error": str(e)},
            status_code=200
        )