dazpye commited on
Commit
1549f24
Β·
verified Β·
1 Parent(s): 7ed3506

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -11
handler.py CHANGED
@@ -5,40 +5,60 @@ import requests
5
  import io
6
 
7
  class EndpointHandler:
8
- def __init__(self, model_dir=None): # AWS expects model_dir
9
  print("πŸ”„ Loading model...")
10
  self.model = CLIPModel.from_pretrained("dazpye/clip-image")
11
  self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")
12
 
13
  def _load_image(self, image_url):
14
- """Simple image loader for URL images."""
15
  try:
16
- print(f"🌐 Fetching image: {image_url}")
17
  response = requests.get(image_url, timeout=5)
18
- response.raise_for_status() # Raise error if status is not 200
19
  return Image.open(io.BytesIO(response.content)).convert("RGB")
20
  except Exception as e:
21
  print(f"❌ Image loading failed: {e}")
22
- return None # Return None if image loading fails
23
 
24
  def __call__(self, data):
25
  """Processes input and runs inference."""
26
- print("πŸ“₯ Received input...")
27
 
28
- text = data.get("inputs", {}).get("text", ["default text"])
29
- image_urls = data.get("inputs", {}).get("images", [])
 
 
 
30
 
31
  images = [self._load_image(url) for url in image_urls if url]
32
  images = [img for img in images if img] # Remove failed images
33
 
34
  if not images:
35
- print("❌ No valid images provided.")
36
  return {"error": "No valid images provided."}
37
 
38
- inputs = self.processor(text=text, images=images, return_tensors="pt")
 
 
 
 
 
 
 
39
 
40
  print("πŸ–₯️ Running inference...")
41
  with torch.no_grad():
42
  outputs = self.model(**inputs)
43
 
44
- return {"predictions": outputs.logits_per_image.softmax(dim=1).tolist()}
 
 
 
 
 
 
 
 
 
 
 
 
5
  import io
6
 
7
  class EndpointHandler:
8
+ def __init__(self, model_dir=None):
9
  print("πŸ”„ Loading model...")
10
  self.model = CLIPModel.from_pretrained("dazpye/clip-image")
11
  self.processor = CLIPProcessor.from_pretrained("dazpye/clip-image")
12
 
13
  def _load_image(self, image_url):
14
+ """Fetches an image from a URL."""
15
  try:
16
+ print(f"🌐 Fetching image from: {image_url}")
17
  response = requests.get(image_url, timeout=5)
18
+ response.raise_for_status()
19
  return Image.open(io.BytesIO(response.content)).convert("RGB")
20
  except Exception as e:
21
  print(f"❌ Image loading failed: {e}")
22
+ return None
23
 
24
  def __call__(self, data):
25
  """Processes input and runs inference."""
26
+ print("πŸ“₯ Processing input...")
27
 
28
+ if "inputs" in data:
29
+ data = data["inputs"]
30
+
31
+ text = data.get("text", ["default text"])
32
+ image_urls = data.get("images", [])
33
 
34
  images = [self._load_image(url) for url in image_urls if url]
35
  images = [img for img in images if img] # Remove failed images
36
 
37
  if not images:
 
38
  return {"error": "No valid images provided."}
39
 
40
+ # Enable padding & truncation to fix tensor error
41
+ inputs = self.processor(
42
+ text=text,
43
+ images=images,
44
+ return_tensors="pt",
45
+ padding=True,
46
+ truncation=True
47
+ )
48
 
49
  print("πŸ–₯️ Running inference...")
50
  with torch.no_grad():
51
  outputs = self.model(**inputs)
52
 
53
+ # Get scores & find best matches
54
+ logits_per_image = outputs.logits_per_image
55
+ probabilities = logits_per_image.softmax(dim=1)
56
+
57
+ # Get top categories per image
58
+ predictions = []
59
+ for i, probs in enumerate(probabilities):
60
+ sorted_indices = torch.argsort(probs, descending=True)
61
+ best_matches = [(text[idx], probs[idx].item()) for idx in sorted_indices[:3]] # Get top 3 matches
62
+ predictions.append({"image_index": i, "top_matches": best_matches})
63
+
64
+ return {"predictions": predictions}