afdx2 commited on
Commit
cebb2ac
·
verified ·
1 Parent(s): 40304b5

Update server1.py

Browse files
Files changed (1) hide show
  1. server1.py +147 -139
server1.py CHANGED
@@ -1,139 +1,147 @@
1
- # app.py
2
- # comentarios sin tildes / sin enye
3
-
4
- import os, io
5
- from typing import Optional
6
- import torch
7
- from fastapi import FastAPI, File, UploadFile
8
- from fastapi.responses import JSONResponse
9
- from PIL import Image, UnidentifiedImageError
10
- import open_clip
11
- from torchvision import transforms as T
12
-
13
- # limites basicos
14
- torch.set_num_threads(1)
15
- os.environ["OMP_NUM_THREADS"] = "1"
16
- os.environ["MKL_NUM_THREADS"] = "1"
17
-
18
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
- DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
20
-
21
- # nombres de ficheros (en el mismo repo)
22
- MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
23
- VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")
24
-
25
- app = FastAPI(title="CLIP H14 Vehicle API")
26
-
27
- # ============== modelo CLIP ==============
28
- clip_model, _, preprocess = open_clip.create_model_and_transforms(
29
- "ViT-H-14", pretrained="laion2b_s32b_b79k"
30
- )
31
- clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
32
- for p in clip_model.parameters():
33
- p.requires_grad = False
34
-
35
- normalize = next(t for t in preprocess.transforms if isinstance(t, T.Normalize))
36
- transform = T.Compose([
37
- T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
38
- T.ToTensor(),
39
- T.Normalize(mean=normalize.mean, std=normalize.std),
40
- ])
41
-
42
- # ============== embeddings ==============
43
- def _ensure_label_list(x):
44
- if isinstance(x, (list, tuple)):
45
- return list(x)
46
- if hasattr(x, "tolist"):
47
- return [str(s) for s in x.tolist()]
48
- return [str(s) for s in x]
49
-
50
- def _load_embeddings(path: str):
51
- ckpt = torch.load(path, map_location="cpu")
52
- labels = _ensure_label_list(ckpt["labels"])
53
- embeds = ckpt["embeddings"].to("cpu", dtype=torch.float16)
54
- embeds = embeds / embeds.norm(dim=-1, keepdim=True)
55
- return labels, embeds
56
-
57
- model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH) # "Marca Modelo"
58
- version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH) # "Marca Modelo Version"
59
-
60
- # ============== inferencia ==============
61
- @torch.inference_mode()
62
- def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
63
- if DEVICE == "cuda":
64
- with torch.cuda.amp.autocast(dtype=DTYPE):
65
- feats = clip_model.encode_image(img_tensor)
66
- else:
67
- feats = clip_model.encode_image(img_tensor)
68
- return feats / feats.norm(dim=-1, keepdim=True)
69
-
70
- def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
71
- img_f = _encode_image(image_tensor)
72
- sim = (100.0 * img_f @ text_feats_dev.T).softmax(dim=-1)[0]
73
- vals, idxs = torch.topk(sim, k=topk)
74
- return [{"label": text_labels[i], "confidence": round(float(v)*100.0, 2)} for v, i in zip(vals, idxs)]
75
-
76
- def process_image_bytes(image_bytes: bytes):
77
- if not image_bytes or len(image_bytes) < 128:
78
- raise UnidentifiedImageError("imagen invalida")
79
-
80
- img = Image.open(io.BytesIO(image_bytes))
81
- if img.mode != "RGB":
82
- img = img.convert("RGB")
83
-
84
- img_tensor = transform(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
85
-
86
- # paso 1: top-1 modelo
87
- model_feats_dev = model_embeddings.to(DEVICE) if DEVICE == "cuda" else model_embeddings
88
- top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
89
- modelo_full = top_model["label"]; conf_m = top_model["confidence"]
90
-
91
- partes = modelo_full.split(" ", 1)
92
- marca = partes[0] if len(partes) >= 1 else ""
93
- modelo = partes[1] if len(partes) == 2 else ""
94
-
95
- # paso 2: filtrar versiones por prefijo
96
- matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
97
- if not matches:
98
- return {
99
- "marca": marca.upper(), "modelo": modelo.title(),
100
- "version": "", "confianza_modelo": conf_m, "confianza_version": 0.0
101
- }
102
-
103
- idxs = [i for _, i in matches]
104
- labels_sub = [lab for lab, _ in matches]
105
- embeds_sub = version_embeddings[idxs].to(DEVICE) if DEVICE == "cuda" else version_embeddings[idxs]
106
-
107
- # paso 3: top-1 version
108
- top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
109
- raw = top_ver["label"]; conf_v = top_ver["confidence"]
110
-
111
- prefix = modelo_full + " "
112
- ver = raw[len(prefix):] if raw.startswith(prefix) else raw
113
- ver = ver.split(" ")[0]
114
- if conf_v < 25.0:
115
- ver = "Version no identificada con suficiente confianza"
116
-
117
- return {
118
- "marca": marca.upper(),
119
- "modelo": modelo.title(),
120
- "version": ver.title() if ver else "",
121
- "confianza_modelo": conf_m,
122
- "confianza_version": conf_v
123
- }
124
-
125
- # ============== endpoints ==============
126
- @app.get("/")
127
- def root():
128
- return {"status": "ok", "device": DEVICE}
129
-
130
- @app.post("/predict")
131
- async def predict(front: UploadFile = File(...), back: Optional[UploadFile] = File(None)):
132
- try:
133
- front_bytes = await front.read()
134
- if back is not None:
135
- _ = await back.read()
136
- result = process_image_bytes(front_bytes)
137
- return JSONResponse(content={"code": 200, "data": result})
138
- except Exception:
139
- return JSONResponse(content={"code": 404, "data": {}}, status_code=200)
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # comentarios sin tildes / sin enye
3
+
4
+ import os, io
5
+ from typing import Optional
6
+ import torch
7
+ from fastapi import FastAPI, File, UploadFile
8
+ from fastapi.responses import JSONResponse
9
+ from PIL import Image, UnidentifiedImageError
10
+ import open_clip
11
+ import os
12
+ os.environ.setdefault("HF_HOME", "/app/cache")
13
+ os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
14
+ os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
15
+ os.environ.setdefault("TRANSFORMERS_CACHE", "/app/cache/huggingface")
16
+ os.environ.setdefault("TORCH_HOME", "/app/cache/torch")
17
+ os.makedirs("/app/cache", exist_ok=True)
18
+
19
+ from torchvision import transforms as T
20
+
21
+ # limites basicos
22
+ torch.set_num_threads(1)
23
+ os.environ["OMP_NUM_THREADS"] = "1"
24
+ os.environ["MKL_NUM_THREADS"] = "1"
25
+
26
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+ DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
28
+
29
+ # nombres de ficheros (en el mismo repo)
30
+ MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
31
+ VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")
32
+
33
+ app = FastAPI(title="CLIP H14 Vehicle API")
34
+
35
+ # ============== modelo CLIP ==============
36
+ clip_model, _, preprocess = open_clip.create_model_and_transforms(
37
+ "ViT-H-14", pretrained="laion2b_s32b_b79k"
38
+ )
39
+ clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
40
+ for p in clip_model.parameters():
41
+ p.requires_grad = False
42
+
43
+ normalize = next(t for t in preprocess.transforms if isinstance(t, T.Normalize))
44
+ transform = T.Compose([
45
+ T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
46
+ T.ToTensor(),
47
+ T.Normalize(mean=normalize.mean, std=normalize.std),
48
+ ])
49
+
50
+ # ============== embeddings ==============
51
+ def _ensure_label_list(x):
52
+ if isinstance(x, (list, tuple)):
53
+ return list(x)
54
+ if hasattr(x, "tolist"):
55
+ return [str(s) for s in x.tolist()]
56
+ return [str(s) for s in x]
57
+
58
+ def _load_embeddings(path: str):
59
+ ckpt = torch.load(path, map_location="cpu")
60
+ labels = _ensure_label_list(ckpt["labels"])
61
+ embeds = ckpt["embeddings"].to("cpu", dtype=torch.float16)
62
+ embeds = embeds / embeds.norm(dim=-1, keepdim=True)
63
+ return labels, embeds
64
+
65
+ model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH) # "Marca Modelo"
66
+ version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH) # "Marca Modelo Version"
67
+
68
+ # ============== inferencia ==============
69
+ @torch.inference_mode()
70
+ def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
71
+ if DEVICE == "cuda":
72
+ with torch.cuda.amp.autocast(dtype=DTYPE):
73
+ feats = clip_model.encode_image(img_tensor)
74
+ else:
75
+ feats = clip_model.encode_image(img_tensor)
76
+ return feats / feats.norm(dim=-1, keepdim=True)
77
+
78
+ def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
79
+ img_f = _encode_image(image_tensor)
80
+ sim = (100.0 * img_f @ text_feats_dev.T).softmax(dim=-1)[0]
81
+ vals, idxs = torch.topk(sim, k=topk)
82
+ return [{"label": text_labels[i], "confidence": round(float(v)*100.0, 2)} for v, i in zip(vals, idxs)]
83
+
84
+ def process_image_bytes(image_bytes: bytes):
85
+ if not image_bytes or len(image_bytes) < 128:
86
+ raise UnidentifiedImageError("imagen invalida")
87
+
88
+ img = Image.open(io.BytesIO(image_bytes))
89
+ if img.mode != "RGB":
90
+ img = img.convert("RGB")
91
+
92
+ img_tensor = transform(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
93
+
94
+ # paso 1: top-1 modelo
95
+ model_feats_dev = model_embeddings.to(DEVICE) if DEVICE == "cuda" else model_embeddings
96
+ top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
97
+ modelo_full = top_model["label"]; conf_m = top_model["confidence"]
98
+
99
+ partes = modelo_full.split(" ", 1)
100
+ marca = partes[0] if len(partes) >= 1 else ""
101
+ modelo = partes[1] if len(partes) == 2 else ""
102
+
103
+ # paso 2: filtrar versiones por prefijo
104
+ matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
105
+ if not matches:
106
+ return {
107
+ "marca": marca.upper(), "modelo": modelo.title(),
108
+ "version": "", "confianza_modelo": conf_m, "confianza_version": 0.0
109
+ }
110
+
111
+ idxs = [i for _, i in matches]
112
+ labels_sub = [lab for lab, _ in matches]
113
+ embeds_sub = version_embeddings[idxs].to(DEVICE) if DEVICE == "cuda" else version_embeddings[idxs]
114
+
115
+ # paso 3: top-1 version
116
+ top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
117
+ raw = top_ver["label"]; conf_v = top_ver["confidence"]
118
+
119
+ prefix = modelo_full + " "
120
+ ver = raw[len(prefix):] if raw.startswith(prefix) else raw
121
+ ver = ver.split(" ")[0]
122
+ if conf_v < 25.0:
123
+ ver = "Version no identificada con suficiente confianza"
124
+
125
+ return {
126
+ "marca": marca.upper(),
127
+ "modelo": modelo.title(),
128
+ "version": ver.title() if ver else "",
129
+ "confianza_modelo": conf_m,
130
+ "confianza_version": conf_v
131
+ }
132
+
133
+ # ============== endpoints ==============
134
+ @app.get("/")
135
+ def root():
136
+ return {"status": "ok", "device": DEVICE}
137
+
138
+ @app.post("/predict")
139
+ async def predict(front: UploadFile = File(...), back: Optional[UploadFile] = File(None)):
140
+ try:
141
+ front_bytes = await front.read()
142
+ if back is not None:
143
+ _ = await back.read()
144
+ result = process_image_bytes(front_bytes)
145
+ return JSONResponse(content={"code": 200, "data": result})
146
+ except Exception:
147
+ return JSONResponse(content={"code": 404, "data": {}}, status_code=200)