SebastianRuff committed on
Commit
07aa66c
·
verified ·
1 Parent(s): 68a8f79

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +42 -13
handler.py CHANGED
@@ -2,6 +2,8 @@ from transformers import AutoModelForCausalLM, AutoProcessor
2
  from PIL import Image
3
  import requests
4
  import torch
 
 
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir):
@@ -15,6 +17,18 @@ class EndpointHandler:
15
  trust_remote_code=True
16
  )
17
  self.device = device
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def __call__(self, data):
20
  try:
@@ -22,24 +36,28 @@ class EndpointHandler:
22
  params = data.get("parameters", {})
23
 
24
  image_url = inputs_data.get("url")
25
- if not image_url or not image_url.startswith("http"):
26
- raise ValueError("Invalid or missing 'url' field")
27
 
28
  image = self.load_image(image_url)
 
 
 
29
  model_inputs = self.processor(
30
  text=inputs_data.get("task_prompt", "<MORE_DETAILED_CAPTION>"),
31
  images=image,
32
  return_tensors="pt"
33
  ).to(self.device)
34
 
35
- generated_ids = self.model.generate(
36
- input_ids=model_inputs["input_ids"],
37
- pixel_values=model_inputs["pixel_values"],
38
- max_new_tokens=params.get("max_new_tokens", 512),
39
- num_beams=params.get("num_beams", 3),
40
- early_stopping=params.get("early_stopping", True),
41
- do_sample=params.get("do_sample", False)
42
- )
 
43
 
44
  generated_text = self.processor.batch_decode(
45
  generated_ids,
@@ -49,12 +67,23 @@ class EndpointHandler:
49
  return {"caption": generated_text}
50
 
51
  except Exception as e:
52
- return {"error": str(e)}
53
 
54
  def load_image(self, image_url):
55
  try:
56
- response = requests.get(image_url, stream=True)
 
 
 
 
 
 
 
 
 
 
 
57
  response.raise_for_status()
58
  return Image.open(response.raw).convert("RGB")
59
  except Exception as e:
60
- raise ValueError(f"Failed to load image: {str(e)}")
 
2
  from PIL import Image
3
  import requests
4
  import torch
5
+ from urllib3.util.retry import Retry
6
+ from requests.adapters import HTTPAdapter
7
 
8
  class EndpointHandler:
9
  def __init__(self, model_dir):
 
17
  trust_remote_code=True
18
  )
19
  self.device = device
20
+ self.session = self._create_session()
21
+
22
def _create_session(self):
    """Build the HTTP session used for image downloads.

    The session retries a request up to 3 times in total, with a
    backoff factor of 0.5, when the server answers with 429, 500, 502,
    503 or 504. The same retry policy is applied to both the
    ``http://`` and ``https://`` prefixes.

    Returns:
        The configured ``requests.Session``.
    """
    retry_policy = Retry(
        total=3,
        backoff_factor=0.5,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    http = requests.Session()
    # One adapter per scheme, each carrying the shared retry policy.
    for scheme in ('http://', 'https://'):
        http.mount(scheme, HTTPAdapter(max_retries=retry_policy))
    return http
32
 
33
  def __call__(self, data):
34
  try:
 
36
  params = data.get("parameters", {})
37
 
38
  image_url = inputs_data.get("url")
39
+ if not image_url:
40
+ return {"error": "Missing URL in inputs"}
41
 
42
  image = self.load_image(image_url)
43
+ if not image:
44
+ return {"error": "Failed to load image"}
45
+
46
  model_inputs = self.processor(
47
  text=inputs_data.get("task_prompt", "<MORE_DETAILED_CAPTION>"),
48
  images=image,
49
  return_tensors="pt"
50
  ).to(self.device)
51
 
52
+ with torch.inference_mode():
53
+ generated_ids = self.model.generate(
54
+ input_ids=model_inputs["input_ids"],
55
+ pixel_values=model_inputs["pixel_values"],
56
+ max_new_tokens=params.get("max_new_tokens", 512),
57
+ num_beams=params.get("num_beams", 3),
58
+ early_stopping=params.get("early_stopping", True),
59
+ do_sample=params.get("do_sample", False)
60
+ )
61
 
62
  generated_text = self.processor.batch_decode(
63
  generated_ids,
 
67
  return {"caption": generated_text}
68
 
69
  except Exception as e:
70
+ return {"error": f"Processing error: {str(e)}"}
71
 
72
  def load_image(self, image_url):
73
  try:
74
+ headers = {
75
+ "User-Agent": "Mozilla/5.0",
76
+ "Accept": "image/jpeg,image/png,image/*",
77
+ "Referer": image_url
78
+ }
79
+ response = self.session.get(
80
+ image_url,
81
+ stream=True,
82
+ headers=headers,
83
+ timeout=15,
84
+ verify=False # Added for SSL issues
85
+ )
86
  response.raise_for_status()
87
  return Image.open(response.raw).convert("RGB")
88
  except Exception as e:
89
+ return None