Update handler.py
Browse files- handler.py +5 -4
handler.py
CHANGED
|
@@ -13,10 +13,10 @@ class EndpointHandler():
|
|
| 13 |
self.processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
| 14 |
|
| 15 |
|
| 16 |
-
def predict_image(self, url, prompt):
|
| 17 |
image = Image.open(requests.get(url, stream=True).raw)
|
| 18 |
|
| 19 |
-
inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, self.torch_dtype)
|
| 20 |
|
| 21 |
generated_ids = self.model.generate(
|
| 22 |
input_ids=inputs["input_ids"],
|
|
@@ -27,7 +27,7 @@ class EndpointHandler():
|
|
| 27 |
)
|
| 28 |
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
| 29 |
|
| 30 |
-
parsed_answer = self.processor.post_process_generation(generated_text, task=
|
| 31 |
return parsed_answer
|
| 32 |
|
| 33 |
|
|
@@ -40,8 +40,9 @@ class EndpointHandler():
|
|
| 40 |
|
| 41 |
inputs = event["inputs"]
|
| 42 |
url = inputs["url"]
|
|
|
|
| 43 |
prompt = inputs["prompt"]
|
| 44 |
-
parsed_answer = self.predict_image(url, prompt)
|
| 45 |
|
| 46 |
return {
|
| 47 |
"statusCode": 200,
|
|
|
|
| 13 |
self.processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
| 14 |
|
| 15 |
|
| 16 |
+
def predict_image(self, url, task, prompt):
|
| 17 |
image = Image.open(requests.get(url, stream=True).raw)
|
| 18 |
|
| 19 |
+
inputs = self.processor(text=task + prompt, images=image, return_tensors="pt").to(self.device, self.torch_dtype)
|
| 20 |
|
| 21 |
generated_ids = self.model.generate(
|
| 22 |
input_ids=inputs["input_ids"],
|
|
|
|
| 27 |
)
|
| 28 |
generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
| 29 |
|
| 30 |
+
parsed_answer = self.processor.post_process_generation(generated_text, task=task, image_size=(image.width, image.height))
|
| 31 |
return parsed_answer
|
| 32 |
|
| 33 |
|
|
|
|
| 40 |
|
| 41 |
inputs = event["inputs"]
|
| 42 |
url = inputs["url"]
|
| 43 |
+
task = inputs["task"]
|
| 44 |
prompt = inputs["prompt"]
|
| 45 |
+
parsed_answer = self.predict_image(url, task, prompt)
|
| 46 |
|
| 47 |
return {
|
| 48 |
"statusCode": 200,
|