# server1.py
import os

# Local caches (avoid permission errors when writing under /).
# These must be set before torch/open_clip are imported so the libraries pick them up.
os.environ.setdefault("HF_HOME", "/app/cache")
os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/app/cache/huggingface")
os.environ.setdefault("TORCH_HOME", "/app/cache/torch")
os.makedirs("/app/cache", exist_ok=True)

# Basic thread limits; these env vars only take effect if set before torch is imported.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import io
import traceback
from typing import Optional

import torch
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
import open_clip
from torchvision import transforms as T

torch.set_num_threads(1)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# paths to the precomputed text embeddings (overridable via env)
MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")
app = FastAPI(title="CLIP H14 Vehicle API")
# ============== CLIP model ==============
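# Weights are fetched into the cache dirs configured above on first start, then frozen.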
clip_model, _, preprocess = open_clip.create_model_and_transforms(
"ViT-H-14", pretrained="laion2b_s32b_b79k"
)
clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
for p in clip_model.parameters():
p.requires_grad = False
normalize = next(t for t in preprocess.transforms if isinstance(t, T.Normalize))
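# Rebuild the eval transform by hand: fixed 224x224 bicubic resize (no center crop),
# reusing the normalization stats from the stock open_clip preprocess pipeline.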
transform = T.Compose([
T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=normalize.mean, std=normalize.std),
])
# ============== embeddings ==============
def _ensure_label_list(x):
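    """Return labels as a plain Python list, stringifying tensor/array entries."""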
if isinstance(x, (list, tuple)):
return list(x)
if hasattr(x, "tolist"):
return [str(s) for s in x.tolist()]
return [str(s) for s in x]
def _load_embeddings(path: str):
    ckpt = torch.load(path, map_location="cpu")
    labels = _ensure_label_list(ckpt["labels"])
    # checkpoints are stored as fp16; normalize in float32 for numerical safety,
    # then cast to DEVICE/DTYPE at use time
    embeds = ckpt["embeddings"].to(torch.float32)
    embeds = embeds / embeds.norm(dim=-1, keepdim=True)
    return labels, embeds
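# Load labels and embeddings once at startup; tensors stay on CPU until a request moves them.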
model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)
# ============== inference ==============
@torch.inference_mode()
def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
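    """Encode a batched image tensor and L2-normalize the features."""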
if DEVICE == "cuda":
        with torch.autocast("cuda", dtype=DTYPE):
feats = clip_model.encode_image(img_tensor)
else:
feats = clip_model.encode_image(img_tensor)
return feats / feats.norm(dim=-1, keepdim=True)
def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
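    """Rank text embeddings against one image; return top-k labels with softmax confidence in percent."""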
img_f = _encode_image(image_tensor)
    # make sure text features share the image features' device and dtype
    text_feats_dev = text_feats_dev.to(device=img_f.device, dtype=img_f.dtype)
    sim = (100.0 * img_f @ text_feats_dev.T).softmax(dim=-1)[0]  # 100.0 ~ CLIP's logit scale
vals, idxs = torch.topk(sim, k=topk)
return [{"label": text_labels[i], "confidence": round(float(v) * 100.0, 2)} for v, i in zip(vals, idxs)]
def process_image_bytes(image_bytes: bytes):
    # returns only the vehicle dict: brand/model/version
    if not image_bytes or len(image_bytes) < 128:
        raise UnidentifiedImageError("invalid image")
    img = Image.open(io.BytesIO(image_bytes))
    if img.mode != "RGB":
        img = img.convert("RGB")
img_tensor = transform(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
    # step 1: top-1 over the "brand model" labels
model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
    modelo_full = top_model["label"]
    # labels look like "Brand Model ..."; the first token is the brand
    partes = modelo_full.split(" ", 1)
    marca = partes[0]
    modelo = partes[1] if len(partes) == 2 else ""
    # step 2: keep only version labels that extend the predicted "brand model" prefix
matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
if not matches:
return {
"brand": marca.upper(),
"model": modelo.title(),
"version": ""
}
idxs = [i for _, i in matches]
labels_sub = [lab for lab, _ in matches]
embeds_sub = version_embeddings[idxs].to(device=DEVICE, dtype=DTYPE)
    # step 3: top-1 over the matching versions
top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
raw = top_ver["label"]
prefix = modelo_full + " "
ver = raw[len(prefix):] if raw.startswith(prefix) else raw
    ver = ver.split(" ")[0]  # keep only the first token of the version suffix
    # if confidence is too low, leave the version empty
if top_ver["confidence"] < 25.0:
ver = ""
return {
"brand": marca.upper(),
"model": modelo.title(),
"version": ver.title() if ver else ""
}
# ============== endpoints ==============
@app.get("/")
def root():
return {"status": "ok", "device": DEVICE}
@app.post("/predict/")
async def predict(request: Request,
                  front: Optional[UploadFile] = File(None),
                  back: Optional[UploadFile] = File(None)):
    try:
        print("headers:", dict(request.headers))
        if front is None:
            # contract: always answer HTTP 200 and signal errors via the body's "code"
            return JSONResponse(
                content={"code": 400, "error": "missing files: 'front' is required"},
                status_code=200
            )
        front_bytes = await front.read()
        if back is not None:
            _ = await back.read()  # 'back' is accepted but currently unused
vehicle = process_image_bytes(front_bytes)
return JSONResponse(
content={"code": 200, "data": {"vehicle": vehicle}},
status_code=200
)
except Exception as e:
print("EXCEPTION:", repr(e))
traceback.print_exc()
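        # same always-200 contract: unexpected errors are reported with code 404 in the body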
return JSONResponse(
content={"code": 404, "data": {}, "error": str(e)},
status_code=200
)
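
if __name__ == "__main__":
    # Local dev entry point; a minimal sketch assuming uvicorn is installed
    # (in deployment the app is usually launched externally, e.g. `uvicorn server1:app`).
    # Example request: curl -F "front=@car.jpg" http://localhost:7860/predict/
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)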