SebastianRuff commited on
Commit
b455b18
·
verified ·
1 Parent(s): 07aa66c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -67
handler.py CHANGED
from transformers import AutoModelForCausalLM, AutoProcessor
from PIL import Image
import io
import requests
import torch


class EndpointHandler:
    """Inference-endpoint handler: fetches an image from a URL and returns a
    caption generated by a multimodal causal LM (Florence-style task prompts).
    """

    def __init__(self, model_dir):
        """Load model and processor from *model_dir*, using GPU when available.

        Args:
            model_dir: Path to the pretrained model / processor directory.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # .eval() disables dropout etc.; torch.inference_mode() alone does not.
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
            .eval()
            .to(device)
        )
        self.processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
        self.device = device

    def __call__(self, data):
        """Handle one inference request.

        Expected payload shape::

            {"inputs": {"url": "...", "task_prompt": "<MORE_DETAILED_CAPTION>"},
             "parameters": {"max_new_tokens": 512, "num_beams": 3}}

        Only ``inputs.url`` is required; all other fields default to the
        values previously hard-coded here, so existing callers are unaffected.

        Returns:
            ``{"caption": str}`` on success, ``{"error": str}`` on failure.
        """
        try:
            inputs_data = data.get("inputs", {})
            params = data.get("parameters", {})

            url = inputs_data.get("url")
            if not url:
                return {"error": "Missing URL"}

            # BUG FIX: without stream=True, requests buffers the body and
            # response.raw is already drained, so Image.open(response.raw)
            # fails. Decode from the buffered bytes instead. Also add a
            # timeout (avoid hanging the worker) and surface HTTP errors
            # before handing the body to PIL.
            # NOTE(review): verify=False (inherited from an earlier revision
            # "for SSL issues") disables TLS certificate checking — a
            # security risk; confirm it is still required.
            response = requests.get(url, timeout=15, verify=False)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")

            inputs = self.processor(
                text=inputs_data.get("task_prompt", "<MORE_DETAILED_CAPTION>"),
                images=image,
                return_tensors="pt",
            ).to(self.device)

            with torch.inference_mode():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=params.get("max_new_tokens", 512),
                    num_beams=params.get("num_beams", 3),
                )

            text = self.processor.batch_decode(output, skip_special_tokens=True)[0]
            return {"caption": text}

        except Exception as e:
            # Endpoint boundary: report the failure to the caller instead of
            # crashing the serving worker.
            return {"error": str(e)}