addgbf committed on
Commit
34e0af9
·
verified ·
1 Parent(s): 95e8893

Update server1.py

Browse files
Files changed (1) hide show
  1. server1.py +27 -34
server1.py CHANGED
@@ -7,7 +7,7 @@ import torch
7
  from fastapi import FastAPI, File, UploadFile, Request
8
  from fastapi.responses import JSONResponse
9
  from PIL import Image, UnidentifiedImageError, ImageFile
10
- from torchvision import transforms as T
11
 
12
  ImageFile.LOAD_TRUNCATED_IMAGES = True
13
 
@@ -25,10 +25,12 @@ os.makedirs(os.environ["TORCH_HOME"], exist_ok=True)
25
 
26
  import open_clip # importar despues de ajustar caches
27
 
28
- # ===== limites basicos =====
29
- torch.set_num_threads(1)
30
- os.environ["OMP_NUM_THREADS"] = "1"
31
- os.environ["MKL_NUM_THREADS"] = "1"
 
 
32
 
33
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
34
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
@@ -40,7 +42,6 @@ MODEL_EMB_PATH = os.getenv("MODEL_EMB_PATH", "text_embeddings_modelos_bigg.pt")
40
  VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_bigg.pt")
41
 
42
  # ===== modelo PE bigG =====
43
- # usamos HF Hub; no pasamos 'pretrained' clasico
44
  MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"
45
  PRETRAINED = None
46
 
@@ -52,36 +53,22 @@ _ret = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
52
  if isinstance(_ret, tuple) and len(_ret) == 3:
53
  clip_model, _preprocess_train, preprocess = _ret
54
  else:
55
- # fallback por si alguna version devuelve solo 2
56
  clip_model, preprocess = _ret
57
 
58
  clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
59
  for p in clip_model.parameters():
60
  p.requires_grad = False
61
 
62
- # extraer normalizacion y size desde el preprocess devuelto
63
- normalize = next(t for t in getattr(preprocess, "transforms", []) if isinstance(t, T.Normalize))
 
 
 
64
  SIZE = next((getattr(t, "size", None) for t in getattr(preprocess, "transforms", []) if hasattr(t, "size")), None)
65
  if isinstance(SIZE, (tuple, list)):
66
- SIZE = max(SIZE) # por si viene como (H,W)
67
  if SIZE is None:
68
- SIZE = 448 # PE bigG es 448; fallback por seguridad
69
-
70
- transform = T.Compose([T.ToTensor(), T.Normalize(mean=normalize.mean, std=normalize.std)])
71
-
72
- # ===== utils imagen =====
73
- def resize_letterbox(img: Image.Image, size: int) -> Image.Image:
74
- if img.mode != "RGB":
75
- img = img.convert("RGB")
76
- w, h = img.size
77
- if w == 0 or h == 0:
78
- raise UnidentifiedImageError("imagen invalida")
79
- scale = size / max(w, h)
80
- nw, nh = max(1, int(w*scale)), max(1, int(h*scale))
81
- img_resized = img.resize((nw, nh), Image.BICUBIC)
82
- canvas = Image.new("RGB", (size, size), (0, 0, 0))
83
- canvas.paste(img_resized, ((size-nw)//2, (size-nh)//2))
84
- return canvas
85
 
86
  # ===== cargar embeddings =====
87
  def _ensure_label_list(x):
@@ -103,8 +90,12 @@ version_labels, version_embeddings = _load_embeddings(VERS_EMB_PATH)
103
 
104
  # comprobar dimension (PE bigG mantiene 1280)
105
  with torch.inference_mode():
106
- dummy = torch.zeros(1, 3, SIZE, SIZE, device=DEVICE, dtype=DTYPE)
107
- img_dim = clip_model.encode_image(dummy).shape[-1]
 
 
 
 
108
  if model_embeddings.shape[1] != img_dim or version_embeddings.shape[1] != img_dim:
109
  raise RuntimeError(
110
  f"dimension mismatch: image={img_dim}, modelos={model_embeddings.shape[1]}, "
@@ -114,15 +105,17 @@ if model_embeddings.shape[1] != img_dim or version_embeddings.shape[1] != img_di
114
  # ===== inferencia =====
115
  @torch.inference_mode()
116
  def _encode_pil(img: Image.Image) -> torch.Tensor:
117
- img = resize_letterbox(img, SIZE)
118
- tensor = transform(img).unsqueeze(0).to(device=DEVICE)
 
119
  if DEVICE == "cuda":
120
- tensor = tensor.to(dtype=DTYPE)
121
- feats = clip_model.encode_image(tensor)
122
  return feats / feats.norm(dim=-1, keepdim=True)
123
 
124
  def _topk_cosine(text_feats: torch.Tensor, text_labels: List[str], img_feat: torch.Tensor, k: int = 1):
125
- sim = (img_feat.float() @ text_feats.to(img_feat.device).float().T)[0]
 
126
  vals, idxs = torch.topk(sim, k=k)
127
  conf = torch.softmax(vals, dim=0)
128
  return [{"label": text_labels[int(i)], "confidence": round(float(c)*100.0, 2)} for i, c in zip(idxs, conf)]
 
7
  from fastapi import FastAPI, File, UploadFile, Request
8
  from fastapi.responses import JSONResponse
9
  from PIL import Image, UnidentifiedImageError, ImageFile
10
+ import multiprocessing as mp
11
 
12
  ImageFile.LOAD_TRUNCATED_IMAGES = True
13
 
 
25
 
26
  import open_clip # importar despues de ajustar caches
27
 
28
+ # ===== limites basicos (usar todos los nucleos) =====
29
+ NTHREADS = max(1, mp.cpu_count())
30
+ torch.set_num_threads(NTHREADS)
31
+ os.environ["OMP_NUM_THREADS"] = str(NTHREADS)
32
+ os.environ["MKL_NUM_THREADS"] = str(NTHREADS)
33
+ # opcional: torch.set_num_interop_threads(1)
34
 
35
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
36
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
 
42
  VERS_EMB_PATH = os.getenv("VERS_EMB_PATH", "text_embeddings_bigg.pt")
43
 
44
  # ===== modelo PE bigG =====
 
45
  MODEL_NAME = "hf-hub:timm/PE-Core-bigG-14-448"
46
  PRETRAINED = None
47
 
 
53
  if isinstance(_ret, tuple) and len(_ret) == 3:
54
  clip_model, _preprocess_train, preprocess = _ret
55
  else:
 
56
  clip_model, preprocess = _ret
57
 
58
  clip_model = clip_model.to(device=DEVICE, dtype=DTYPE).eval()
59
  for p in clip_model.parameters():
60
  p.requires_grad = False
61
 
62
+ # opcional GPU: formato canales para mejorar rendimiento
63
+ if DEVICE == "cuda":
64
+ clip_model = clip_model.to(memory_format=torch.channels_last)
65
+
66
+ # obtener SIZE desde preprocess solo para chequeos
67
  SIZE = next((getattr(t, "size", None) for t in getattr(preprocess, "transforms", []) if hasattr(t, "size")), None)
68
  if isinstance(SIZE, (tuple, list)):
69
+ SIZE = max(SIZE)
70
  if SIZE is None:
71
+ SIZE = 448 # fallback por seguridad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # ===== cargar embeddings =====
74
  def _ensure_label_list(x):
 
90
 
91
  # comprobar dimension (PE bigG mantiene 1280)
92
  with torch.inference_mode():
93
+ dummy = Image.new("RGB", (SIZE, SIZE), (0, 0, 0))
94
+ tensor = preprocess(dummy).unsqueeze(0).to(device=DEVICE)
95
+ if DEVICE == "cuda":
96
+ tensor = tensor.to(dtype=DTYPE)
97
+ img_dim = clip_model.encode_image(tensor).shape[-1]
98
+
99
  if model_embeddings.shape[1] != img_dim or version_embeddings.shape[1] != img_dim:
100
  raise RuntimeError(
101
  f"dimension mismatch: image={img_dim}, modelos={model_embeddings.shape[1]}, "
 
105
  # ===== inferencia =====
106
  @torch.inference_mode()
107
  def _encode_pil(img: Image.Image) -> torch.Tensor:
108
+ if img.mode != "RGB":
109
+ img = img.convert("RGB")
110
+ x = preprocess(img).unsqueeze(0).to(device=DEVICE)
111
  if DEVICE == "cuda":
112
+ x = x.to(dtype=DTYPE)
113
+ feats = clip_model.encode_image(x)
114
  return feats / feats.norm(dim=-1, keepdim=True)
115
 
116
  def _topk_cosine(text_feats: torch.Tensor, text_labels: List[str], img_feat: torch.Tensor, k: int = 1):
117
+ tf = text_feats.to(img_feat.device, non_blocking=True)
118
+ sim = (img_feat @ tf.T)[0] # tensores ya normalizados
119
  vals, idxs = torch.topk(sim, k=k)
120
  conf = torch.softmax(vals, dim=0)
121
  return [{"label": text_labels[int(i)], "confidence": round(float(c)*100.0, 2)} for i, c in zip(idxs, conf)]