supib4132 commited on
Commit
350853d
·
verified ·
1 Parent(s): c7e3382

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +19 -2
inference.py CHANGED
@@ -31,11 +31,28 @@ def extract_image_features(image):
31
  Input: PIL Image or image path (str).
32
  Output: Normalized image embedding (numpy array).
33
  """
34
- try:
35
- if isinstance(image, str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  image = Image.open(image).convert("RGB")
37
  else:
38
  image = image.convert("RGB")
 
39
  inputs = clip_processor(images=image, return_tensors="pt")
40
  with torch.no_grad():
41
  features = clip_model.get_image_features(**inputs)
 
31
  Input: PIL Image or image path (str).
32
  Output: Normalized image embedding (numpy array).
33
  """
34
+ # try:
35
+ # if isinstance(image, str):
36
+ # image = Image.open(image).convert("RGB")
37
+ # else:
38
+ # image = image.convert("RGB")
39
+ # inputs = clip_processor(images=image, return_tensors="pt")
40
+ # with torch.no_grad():
41
+ # features = clip_model.get_image_features(**inputs)
42
+ # features = torch.nn.functional.normalize(features, p=2, dim=-1)
43
+ # return features.squeeze(0).cpu().numpy().astype("float32")
44
+ # except Exception as e:
45
+ # print(f"Error extracting features: {e}")
46
+ # return None
47
+ try:
48
+ # Convert NumPy array to PIL if needed
49
+ if isinstance(image, np.ndarray):
50
+ image = Image.fromarray(image.astype("uint8")).convert("RGB")
51
+ elif isinstance(image, str):
52
  image = Image.open(image).convert("RGB")
53
  else:
54
  image = image.convert("RGB")
55
+
56
  inputs = clip_processor(images=image, return_tensors="pt")
57
  with torch.no_grad():
58
  features = clip_model.get_image_features(**inputs)