mountainsma committed on
Commit
5eb4524
·
1 Parent(s): bcfaacf

Fixed exif rotation + squashing

Browse files
Files changed (1) hide show
  1. handler.py +41 -7
handler.py CHANGED
@@ -6,7 +6,10 @@ from urllib.request import urlopen
6
 
7
  import open_clip
8
  import torch
9
- from PIL import Image
 
 
 
10
 
11
 
12
  def _is_git_lfs_pointer(path: Path) -> bool:
@@ -25,13 +28,14 @@ class EndpointHandler:
25
  self._validate_model_files()
26
 
27
  model_id = f"local-dir:{self.model_dir}"
28
- self.model, self.preprocess = open_clip.create_model_from_pretrained(
29
  model_id,
30
  device=self.device,
31
  return_transform=True,
32
  )
33
  self.tokenizer = open_clip.get_tokenizer(model_id)
34
  self.model.eval()
 
35
 
36
  def _validate_model_files(self) -> None:
37
  config_path = self.model_dir / "open_clip_config.json"
@@ -61,16 +65,46 @@ class EndpointHandler:
61
  "Upload the actual LFS blobs to the Hugging Face model repo before starting the endpoint."
62
  )
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def _load_image(self, image_input: Any) -> Image.Image | None:
65
  if not isinstance(image_input, str):
66
  return None
67
 
68
  if image_input.startswith(("http://", "https://")):
69
  with urlopen(image_input, timeout=10) as response:
70
- return Image.open(io.BytesIO(response.read())).convert("RGB")
 
 
 
 
 
 
71
 
72
- image_bytes = base64.b64decode(image_input.split(",")[-1])
73
- return Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
74
 
75
  def _tokenize_text(self, text: str | List[str]) -> torch.Tensor:
76
  texts = text if isinstance(text, list) else [text]
@@ -86,7 +120,7 @@ class EndpointHandler:
86
 
87
  with torch.no_grad():
88
  if image is not None and text is not None:
89
- image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
90
  text_tensor = self._tokenize_text(text)
91
 
92
  image_features = self.model.encode_image(image_tensor, normalize=True)
@@ -99,7 +133,7 @@ class EndpointHandler:
99
  response["text_embedding"] = text_features[0].cpu().tolist()
100
  return response
101
  elif image is not None:
102
- image_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
103
  image_features = self.model.encode_image(image_tensor, normalize=True)
104
  return {"image_embedding": image_features[0].cpu().tolist()}
105
  elif text is not None:
 
6
 
7
  import open_clip
8
  import torch
9
+ from PIL import Image, ImageOps
10
+ from torchvision.transforms import Compose, Normalize, ToTensor
11
+
12
+ INPUT_SIZE = 224
13
 
14
 
15
  def _is_git_lfs_pointer(path: Path) -> bool:
 
28
  self._validate_model_files()
29
 
30
  model_id = f"local-dir:{self.model_dir}"
31
+ self.model, preprocess = open_clip.create_model_from_pretrained(
32
  model_id,
33
  device=self.device,
34
  return_transform=True,
35
  )
36
  self.tokenizer = open_clip.get_tokenizer(model_id)
37
  self.model.eval()
38
+ self.tensor_preprocess = self._build_tensor_preprocess(preprocess)
39
 
40
  def _validate_model_files(self) -> None:
41
  config_path = self.model_dir / "open_clip_config.json"
 
65
  "Upload the actual LFS blobs to the Hugging Face model repo before starting the endpoint."
66
  )
67
 
68
+ @staticmethod
69
+ def _build_tensor_preprocess(original_preprocess) -> Compose:
70
+ """Extract Normalize from the model's preprocess and build ToTensor + Normalize only.
71
+
72
+ The default model preprocess includes Resize + CenterCrop + ToTensor + Normalize.
73
+ Since we manually squash images to INPUT_SIZE x INPUT_SIZE, we only need
74
+ ToTensor + Normalize to match the existing embedding pipeline.
75
+ """
76
+ normalize = None
77
+ for t in original_preprocess.transforms:
78
+ if isinstance(t, Normalize):
79
+ normalize = t
80
+ break
81
+ if normalize is None:
82
+ normalize = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
83
+ return Compose([ToTensor(), normalize])
84
+
85
+ @staticmethod
86
+ def _prepare_image(img: Image.Image) -> Image.Image:
87
+ """Squash image to INPUT_SIZE x INPUT_SIZE."""
88
+ return img.resize((INPUT_SIZE, INPUT_SIZE), Image.BICUBIC)
89
+
90
  def _load_image(self, image_input: Any) -> Image.Image | None:
91
  if not isinstance(image_input, str):
92
  return None
93
 
94
  if image_input.startswith(("http://", "https://")):
95
  with urlopen(image_input, timeout=10) as response:
96
+ img = Image.open(io.BytesIO(response.read()))
97
+ else:
98
+ image_bytes = base64.b64decode(image_input.split(",")[-1])
99
+ img = Image.open(io.BytesIO(image_bytes))
100
+
101
+ img = ImageOps.exif_transpose(img)
102
+ return img.convert("RGB")
103
 
104
+ def _preprocess_image(self, image: Image.Image) -> torch.Tensor:
105
+ """Squash to INPUT_SIZE and apply tensor normalization."""
106
+ image = self._prepare_image(image)
107
+ return self.tensor_preprocess(image).unsqueeze(0).to(self.device)
108
 
109
  def _tokenize_text(self, text: str | List[str]) -> torch.Tensor:
110
  texts = text if isinstance(text, list) else [text]
 
120
 
121
  with torch.no_grad():
122
  if image is not None and text is not None:
123
+ image_tensor = self._preprocess_image(image)
124
  text_tensor = self._tokenize_text(text)
125
 
126
  image_features = self.model.encode_image(image_tensor, normalize=True)
 
133
  response["text_embedding"] = text_features[0].cpu().tolist()
134
  return response
135
  elif image is not None:
136
+ image_tensor = self._preprocess_image(image)
137
  image_features = self.model.encode_image(image_tensor, normalize=True)
138
  return {"image_embedding": image_features[0].cpu().tolist()}
139
  elif text is not None: