Drazcat-AI commited on
Commit
7105a0a
·
verified ·
1 Parent(s): d601c8d

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -4
handler.py CHANGED
@@ -13,10 +13,10 @@ class EndpointHandler():
13
  self.processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
14
 
15
 
16
- def predict_image(self, url, prompt):
17
  image = Image.open(requests.get(url, stream=True).raw)
18
 
19
- inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, self.torch_dtype)
20
 
21
  generated_ids = self.model.generate(
22
  input_ids=inputs["input_ids"],
@@ -27,7 +27,7 @@ class EndpointHandler():
27
  )
28
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
29
 
30
- parsed_answer = self.processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))
31
  return parsed_answer
32
 
33
 
@@ -40,8 +40,9 @@ class EndpointHandler():
40
 
41
  inputs = event["inputs"]
42
  url = inputs["url"]
 
43
  prompt = inputs["prompt"]
44
- parsed_answer = self.predict_image(url, prompt)
45
 
46
  return {
47
  "statusCode": 200,
 
13
  self.processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
14
 
15
 
16
+ def predict_image(self, url, task, prompt):
17
  image = Image.open(requests.get(url, stream=True).raw)
18
 
19
+ inputs = self.processor(text=task + prompt, images=image, return_tensors="pt").to(self.device, self.torch_dtype)
20
 
21
  generated_ids = self.model.generate(
22
  input_ids=inputs["input_ids"],
 
27
  )
28
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
29
 
30
+ parsed_answer = self.processor.post_process_generation(generated_text, task=task, image_size=(image.width, image.height))
31
  return parsed_answer
32
 
33
 
 
40
 
41
  inputs = event["inputs"]
42
  url = inputs["url"]
43
+ task = inputs["task"]
44
  prompt = inputs["prompt"]
45
+ parsed_answer = self.predict_image(url, task, prompt)
46
 
47
  return {
48
  "statusCode": 200,