Update server1.py

#1
by addgbf - opened
Files changed (1) hide show
  1. server1.py +67 -157
server1.py CHANGED
@@ -1,173 +1,83 @@
1
- # app.py
2
  # comentarios sin tildes / sin enye
3
 
4
- import os, io, traceback
5
- from typing import Optional
6
- import torch
7
- from fastapi import FastAPI, File, UploadFile, Request
8
  from fastapi.responses import JSONResponse
9
- from PIL import Image, UnidentifiedImageError
10
- import open_clip
11
- from torchvision import transforms as T
12
-
13
- # caches locales (evitar permisos en /)
14
- os.environ.setdefault("HF_HOME", "/app/cache")
15
- os.environ.setdefault("XDG_CACHE_HOME", "/app/cache")
16
- os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/app/cache/huggingface")
17
- os.environ.setdefault("TRANSFORMERS_CACHE", "/app/cache/huggingface")
18
- os.environ.setdefault("TORCH_HOME", "/app/cache/torch")
19
- os.makedirs("/app/cache", exist_ok=True)
20
-
21
- # limites basicos
22
- torch.set_num_threads(1)
23
- os.environ["OMP_NUM_THREADS"] = "1"
24
- os.environ["MKL_NUM_THREADS"] = "1"
25
 
26
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
- DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
28
-
29
- # rutas a embeddings
30
- MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_h14.pt")
31
- VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_h14.pt")
32
-
33
- app = FastAPI(title="CLIP H14 Vehicle API")
34
-
35
- # ============== modelo CLIP ==============
36
- clip_model, _, preprocess = open_clip.create_model_and_transforms(
37
- "ViT-H-14", pretrained="laion2b_s32b_b79k"
38
- )
39
- clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
40
- for p in clip_model.parameters():
41
- p.requires_grad = False
42
-
43
- normalize = next(t for t in preprocess.transforms if isinstance(t, T.Normalize))
44
- transform = T.Compose([
45
- T.Resize((224, 224), interpolation=T.InterpolationMode.BICUBIC),
46
- T.ToTensor(),
47
- T.Normalize(mean=normalize.mean, std=normalize.std),
48
- ])
49
-
50
- # ============== embeddings ==============
51
- def _ensure_label_list(x):
52
- if isinstance(x, (list, tuple)):
53
- return list(x)
54
- if hasattr(x, "tolist"):
55
- return [str(s) for s in x.tolist()]
56
- return [str(s) for s in x]
57
-
58
- def _load_embeddings(path: str):
59
- ckpt = torch.load(path, map_location="cpu")
60
- labels = _ensure_label_list(ckpt["labels"])
61
- embeds = ckpt["embeddings"].to("cpu") # guardados como fp16; los castearemos mas tarde
62
- embeds = embeds / embeds.norm(dim=-1, keepdim=True)
63
- return labels, embeds
64
-
65
- model_labels, model_embeddings = _load_embeddings(MODEL_EMB_PATH)
66
- version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)
67
-
68
- # ============== inferencia ==============
69
- @torch.inference_mode()
70
- def _encode_image(img_tensor: torch.Tensor) -> torch.Tensor:
71
- if DEVICE == "cuda":
72
- with torch.cuda.amp.autocast(dtype=DTYPE):
73
- feats = clip_model.encode_image(img_tensor)
74
- else:
75
- feats = clip_model.encode_image(img_tensor)
76
- return feats / feats.norm(dim=-1, keepdim=True)
77
-
78
- def _predict_top(text_feats_dev: torch.Tensor, text_labels: list[str], image_tensor: torch.Tensor, topk: int = 1):
79
- img_f = _encode_image(image_tensor)
80
- # asegurar mismo device y dtype
81
- text_feats_dev = text_feats_dev.to(device=img_f.device, dtype=img_f.dtype)
82
- sim = (100.0 * img_f @ text_feats_dev.T).softmax(dim=-1)[0]
83
- vals, idxs = torch.topk(sim, k=topk)
84
- return [{"label": text_labels[i], "confidence": round(float(v) * 100.0, 2)} for v, i in zip(vals, idxs)]
85
-
86
- def process_image_bytes(image_bytes: bytes):
87
- # devuelve solo el dict vehicle: brand/model/version
88
- if not image_bytes or len(image_bytes) < 128:
89
- raise UnidentifiedImageError("imagen invalida")
90
 
91
  img = Image.open(io.BytesIO(image_bytes))
92
  if img.mode != "RGB":
93
  img = img.convert("RGB")
94
 
95
- img_tensor = transform(img).unsqueeze(0).to(device=DEVICE, dtype=DTYPE)
96
-
97
- # paso 1: top-1 modelo
98
- model_feats_dev = model_embeddings.to(device=DEVICE, dtype=DTYPE)
99
- top_model = _predict_top(model_feats_dev, model_labels, img_tensor, topk=1)[0]
100
- modelo_full = top_model["label"]
101
-
102
- partes = modelo_full.split(" ", 1)
103
- marca = partes[0] if len(partes) >= 1 else ""
104
- modelo = partes[1] if len(partes) == 2 else ""
105
-
106
- # paso 2: filtrar versiones por prefijo
107
- matches = [(lab, idx) for idx, lab in enumerate(version_labels) if lab.startswith(modelo_full)]
108
- if not matches:
109
- return {
110
- "brand": marca.upper(),
111
- "model": modelo.title(),
112
- "version": ""
113
- }
114
-
115
- idxs = [i for _, i in matches]
116
- labels_sub = [lab for lab, _ in matches]
117
- embeds_sub = version_embeddings[idxs].to(device=DEVICE, dtype=DTYPE)
118
-
119
- # paso 3: top-1 version
120
- top_ver = _predict_top(embeds_sub, labels_sub, img_tensor, topk=1)[0]
121
- raw = top_ver["label"]
122
-
123
- prefix = modelo_full + " "
124
- ver = raw[len(prefix):] if raw.startswith(prefix) else raw
125
- ver = ver.split(" ")[0]
126
-
127
- # si baja confianza, no rellenamos version
128
- if top_ver["confidence"] < 25.0:
129
- ver = ""
130
-
131
- return {
132
- "brand": marca.upper(),
133
- "model": modelo.title(),
134
- "version": ver.title() if ver else ""
135
- }
136
-
137
- # ============== endpoints ==============
138
- @app.get("/")
139
- def root():
140
- return {"status": "ok", "device": DEVICE}
141
-
142
- @app.post("/predict/")
143
- async def predict(front: UploadFile = File(None),
144
- back: Optional[UploadFile] = File(None),
145
- request: Request = None):
146
  try:
147
- if request:
148
- print("headers:", dict(request.headers))
149
- if front is None:
150
- return JSONResponse(
151
- content={"code": 400, "error": "faltan archivos: 'front' es obligatorio"},
152
- status_code=200
153
- )
154
-
155
- front_bytes = await front.read()
156
- if back is not None:
157
- _ = await back.read()
158
-
159
- vehicle = process_image_bytes(front_bytes)
160
  return JSONResponse(
161
- content={"code": 200, "data": {"vehicle": vehicle}},
162
- status_code=200
163
  )
164
-
165
  except Exception as e:
166
- print("EXCEPTION:", repr(e))
167
- traceback.print_exc()
168
  return JSONResponse(
169
- content={"code": 404, "data": {}, "error": str(e)},
170
- status_code=200
171
  )
172
-
173
-
 
 
1
  # comentarios sin tildes / sin enye
2
 
3
+ import io, os
4
+ import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File
 
6
  from fastapi.responses import JSONResponse
7
+ from PIL import Image
8
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
9
+ import torch
10
+
11
# FastAPI application object for the strip analysis service.
app = FastAPI(title="Accudoctor Strip Analyzer")

# Prefer the GPU when one is visible, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Segment Anything ViT-H checkpoint once at import time and build
# the automatic mask generator used by analyze_strip().
# NOTE(review): the original comment said "SAM2", but sam_model_registry is
# the SAM v1 API — confirm which model family is intended.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device=DEVICE)
mask_generator = SamAutomaticMaskGenerator(sam)
19
+
20
def dominant_color(pil_img):
    """Return the most frequent pixel colour of *pil_img* as a '#rrggbb' string.

    The image is downscaled to 60x60 first so the exact-match frequency count
    stays cheap and minor sensor noise is smoothed out.

    Args:
        pil_img: a PIL image (any mode; converted to RGB internally).

    Returns:
        Lower-case hex colour string, e.g. "#ff0000".
    """
    # Normalize the mode so the (-1, 3) reshape below cannot fail on
    # paletted ("P") or alpha ("RGBA") inputs.
    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")
    small = pil_img.resize((60, 60))
    pixels = np.array(small).reshape((-1, 3))
    # Count exact RGB triples and keep the most common one.
    colors, counts = np.unique(pixels, axis=0, return_counts=True)
    dom = colors[counts.argmax()]
    # int() keeps the format spec working regardless of the numpy scalar type.
    return "#{:02x}{:02x}{:02x}".format(int(dom[0]), int(dom[1]), int(dom[2]))
29
+
30
def analyze_strip(image_bytes, max_blocks=11):
    """Segment a test-strip photo and describe its colour pads.

    Runs SAM automatic mask generation over the image, keeps only tall,
    reasonably large regions (presumably the reagent pads), and returns them
    ordered top-to-bottom.

    Args:
        image_bytes: raw encoded image bytes (any format PIL can open).
        max_blocks: maximum number of pads returned. Defaults to 11 to keep
            the original behaviour (11 pads per strip — TODO confirm).

    Returns:
        List of dicts with keys "index" (1-based, top to bottom), "bbox"
        ([x1, y1, x2, y2] as ints) and "color_hex" ('#rrggbb').
    """
    img = Image.open(io.BytesIO(image_bytes))
    if img.mode != "RGB":
        img = img.convert("RGB")

    np_img = np.array(img)
    masks = mask_generator.generate(np_img)

    height = np_img.shape[0]
    blocks = []
    for m in masks:
        # SamAutomaticMaskGenerator reports bboxes in XYWH format.
        x, y, w, h = m["bbox"]

        # Keep only clearly vertical regions (height at least 3x width);
        # the epsilon avoids division by zero on degenerate boxes.
        if h / (w + 1e-6) < 3:
            continue
        # Discard slivers shorter than 4% of the image height.
        if h < height * 0.04:
            continue

        crop = img.crop((x, y, x + w, y + h))
        blocks.append({
            "bbox": [int(x), int(y), int(x + w), int(y + h)],
            "color_hex": dominant_color(crop),
            "y_center": y + h / 2,
        })

    # Order pads top-to-bottom, then assign 1-based indices and drop the
    # sort key before returning.
    blocks.sort(key=lambda b: b["y_center"])
    for i, b in enumerate(blocks, start=1):
        b["index"] = i
        del b["y_center"]
    return blocks[:max_blocks]
68
+
69
@app.post("/strip/")
async def strip(front: UploadFile = File(...)):
    """Analyze an uploaded strip photo and return the detected colour pads.

    Always answers HTTP 200; the application-level status lives in the
    "code" field (200 with "blocks" on success, 500 with "error" on failure).
    """
    try:
        payload = await front.read()
        blocks = analyze_strip(payload)
    except Exception as e:  # boundary handler: report the failure to the client
        return JSONResponse(
            status_code=200,
            content={"code": 500, "error": str(e)}
        )
    return JSONResponse(
        status_code=200,
        content={"code": 200, "blocks": blocks}
    )