Spaces:
Sleeping
Sleeping
File size: 5,994 Bytes
cebb2ac 31ffd52 cebb2ac 31ffd52 cebb2ac 31ffd52 6a5790f cebb2ac 6a5790f cebb2ac 6a5790f cebb2ac 31ffd52 cebb2ac 6a5790f 31ffd52 cebb2ac 6a5790f cebb2ac 6a5790f cebb2ac 31ffd52 cebb2ac 6a5790f cebb2ac 6a5790f cebb2ac 31ffd52 cebb2ac 6a5790f cebb2ac 6a5790f cebb2ac 6a5790f cebb2ac eb8afd8 6a5790f cebb2ac 3868453 6a5790f 3868453 cebb2ac 3868453 6a5790f 3868453 6a5790f 31ffd52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# app.py
# CLIP ViT-H-14 vehicle brand/model/version classification API.
#
# NOTE(review): the cache environment variables below must be set BEFORE
# importing torch / open_clip / huggingface-backed libraries -- those
# libraries resolve their cache directories at import time, so setting the
# variables after the imports (as the original code did) has no effect.
import os

# Local caches (avoid permission errors when writing under /).
os.environ.setdefault("HF_HOME", "/app/cache")
os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/app/cache/huggingface")
os.environ.setdefault("TORCH_HOME", "/app/cache/torch")
os.makedirs("/app/cache", exist_ok=True)

# Basic limits: keep BLAS/OpenMP to a single thread (small container).
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import io
import traceback
from typing import Optional

import torch
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
import open_clip
from torchvision import transforms as T

torch.set_num_threads(1)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 on GPU for memory/speed; fp32 on CPU where fp16 kernels are limited.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Paths to precomputed text-embedding checkpoints (overridable via env).
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")

app = FastAPI(title="CLIP H14 Vehicle API")
# ============== CLIP model ==============
# Load ViT-H-14 pretrained on LAION-2B. `preprocess` is used only to recover
# the normalization statistics; the served transform is rebuilt below.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-H-14", pretrained="laion2b_s32b_b79k"
)
# Inference only: move to device, cast to DTYPE, and freeze all parameters.
clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for p in clip_model.parameters():
    p.requires_grad = False
# Pull the Normalize step out of the official preprocess pipeline so the
# custom transform uses the exact same mean/std the model was trained with.
normalize = next(t for t in preprocess.transforms if isinstance(t, T.Normalize))
# Fixed 224x224 bicubic resize (no center crop, unlike the stock preprocess).
transform = T.Compose([
    T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(mean=normalize.mean, std=normalize.std),
])
# ============== embeddings ==============
def _ensure_label_list(x):
if isinstance(x, (list, tuple)):
return list(x)
if hasattr(x, "tolist"):
return [str(s) for s in x.tolist()]
return [str(s) for s in x]
def _load_embeddings(path: str):
    """Load a text-embedding checkpoint; return (labels, unit-norm embeddings).

    The checkpoint is a dict with "labels" and "embeddings". Embeddings stay
    on CPU (stored as fp16; cast to the target device/dtype later at
    inference time) and are L2-normalized row-wise so that cosine similarity
    reduces to a dot product.
    """
    checkpoint = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(checkpoint["labels"])
    embeddings = checkpoint["embeddings"].to("cpu")
    norms = embeddings.norm(dim=-1, keepdim=True)
    return labels, embeddings / norms
# Load both label sets at startup: "Brand Model" labels for stage 1 and the
# longer version labels (which extend a "Brand Model" prefix) for stage 2.
model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)
# ============== inference ==============
@torch.inference_mode()
def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
    """Encode a preprocessed image batch and return L2-normalized features.

    On CUDA the encoder runs under autocast so the fp16 model executes in
    mixed precision; on CPU it runs directly in fp32.
    """
    if DEVICE == "cuda":
        # torch.autocast is the supported API; torch.cuda.amp.autocast is
        # deprecated and emits a FutureWarning on recent torch versions.
        with torch.autocast(device_type="cuda", dtype=DTYPE):
            feats = clip_model.encode_image(img_tensor)
    else:
        feats = clip_model.encode_image(img_tensor)
    return feats / feats.norm(dim=-1, keepdim=True)
def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
    """Rank text labels against one image; return top-k labels with confidence.

    Confidence is the softmax over (100 * cosine similarity) across the given
    label set, expressed as a percentage rounded to two decimals.
    """
    image_features = _encode_image(image_tensor)
    # Bring the text features onto the image features' device and dtype.
    text_features = text_feats_dev.to(device=image_features.device, dtype=image_features.dtype)
    logits = 100.0 * image_features @ text_features.T
    probs = logits.softmax(dim=-1)[0]
    top_vals, top_idxs = torch.topk(probs, k=topk)
    results = []
    for score, label_idx in zip(top_vals, top_idxs):
        results.append({
            "label": text_labels[label_idx],
            "confidence": round(float(score) * 100.0, 2),
        })
    return results
# Minimum confidence (in percent) below which the version field is left empty.
MIN_VERSION_CONFIDENCE = 25.0


def process_image_bytes(image_bytes: bytes, min_version_conf: float = MIN_VERSION_CONFIDENCE):
    """Classify a vehicle photo; return {"brand", "model", "version"}.

    Two-stage lookup: (1) top-1 over the "Brand Model" label set, then
    (2) top-1 over the version labels restricted to those starting with the
    predicted "Brand Model" string. The version is left empty when there are
    no matching version labels or its confidence is below ``min_version_conf``
    percent (default preserves the original hard-coded 25.0 threshold).

    Raises:
        UnidentifiedImageError: if the payload is empty/too small or PIL
            cannot decode it.
    """
    if not image_bytes or len(image_bytes) < 128:
        raise UnidentifiedImageError("imagen invalida")
    img = Image.open(io.BytesIO(image_bytes))
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)

    # Stage 1: top-1 "Brand Model" label.
    model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
    top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
    modelo_full = top_model["label"]
    # First token is the brand, the remainder (possibly empty) is the model.
    # str.partition covers the original split(" ", 1) logic; the old
    # `len(partes) >= 1` guard was always true since split never returns [].
    marca, _, modelo = modelo_full.partition(" ")

    # Stage 2: restrict version labels to those extending the predicted model.
    matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
    ver = ""
    if matches:
        idxs = [i for _, i in matches]
        labels_sub = [lab for lab, _ in matches]
        embeds_sub = version_embeddings[idxs].to(device=DEVICE, dtype=DTYPE)
        top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
        raw = top_ver["label"]
        prefix = modelo_full + " "
        ver = raw[len(prefix):] if raw.startswith(prefix) else raw
        # Keep only the first token of the remaining version string.
        ver = ver.split(" ")[0]
        # Low confidence -> do not report a version at all.
        if top_ver["confidence"] < min_version_conf:
            ver = ""
    return {
        "brand": marca.upper(),
        "model": modelo.title(),
        "version": ver.title() if ver else "",
    }
# ============== endpoints ==============
@app.get("/")
def root():
    """Health-check endpoint: reports service status and compute device."""
    payload = {"status": "ok", "device": DEVICE}
    return payload
@app.post("/predict/")
async def predict(front: UploadFile = File(None),
                  back: Optional[UploadFile] = File(None),
                  request: Request = None):
    """Classify the 'front' vehicle photo; 'back' is accepted but unused.

    NOTE(review): every response -- success, missing-file validation, and
    unexpected exception -- is returned with HTTP status 200; the actual
    outcome is carried in the body's "code" field (200 / 400 / 404).
    Confirm clients depend on this before normalizing the status codes.
    """
    try:
        # Debug aid: dump incoming request headers to stdout.
        if request:
            print("headers:", dict(request.headers))
        if front is None:
            # Missing required file: body signals 400, HTTP status stays 200.
            return JSONResponse(
                content={"code": 400, "error": "faltan archivos: 'front' es obligatorio"},
                status_code=200
            )
        front_bytes = await front.read()
        # The back image is drained but not used for prediction.
        if back is not None:
            _ = await back.read()
        vehicle = process_image_bytes(front_bytes)
        return JSONResponse(
            content={"code": 200, "data": {"vehicle": vehicle}},
            status_code=200
        )
    except Exception as e:
        # Any failure (undecodable image, model error) is logged and reported
        # as code 404 in the body. NOTE(review): 500 would be the conventional
        # code for a server-side failure -- confirm the API contract.
        print("EXCEPTION:", repr(e))
        traceback.print_exc()
        return JSONResponse(
            content={"code": 404, "data": {}, "error": str(e)},
            status_code=200
        )
|