Update handler.py
Browse files- handler.py +8 -2
handler.py
CHANGED
|
@@ -1,18 +1,24 @@
|
|
| 1 |
from transformers import AutoModelForCausalLM, AutoProcessor
|
| 2 |
from PIL import Image
|
| 3 |
import requests
|
|
|
|
| 4 |
|
| 5 |
class EndpointHandler:
|
| 6 |
def __init__(self, model_dir):
|
|
|
|
|
|
|
|
|
|
| 7 |
# Load the model with trust_remote_code=True
|
| 8 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 9 |
model_dir,
|
| 10 |
trust_remote_code=True
|
| 11 |
-
).eval().
|
|
|
|
| 12 |
self.processor = AutoProcessor.from_pretrained(
|
| 13 |
model_dir,
|
| 14 |
trust_remote_code=True
|
| 15 |
)
|
|
|
|
| 16 |
|
| 17 |
def __call__(self, data):
|
| 18 |
# Extract inputs from the request data
|
|
@@ -27,7 +33,7 @@ class EndpointHandler:
|
|
| 27 |
text=task_prompt,
|
| 28 |
images=image,
|
| 29 |
return_tensors="pt"
|
| 30 |
-
).to(
|
| 31 |
|
| 32 |
# Generate output
|
| 33 |
generated_ids = self.model.generate(
|
|
|
|
| 1 |
from transformers import AutoModelForCausalLM, AutoProcessor
|
| 2 |
from PIL import Image
|
| 3 |
import requests
|
| 4 |
+
import torch
|
| 5 |
|
| 6 |
class EndpointHandler:
|
| 7 |
def __init__(self, model_dir):
    """Initialize the endpoint: load model and processor from *model_dir*.

    Args:
        model_dir: Path to the directory containing the pretrained
            model and processor files.

    Side effects:
        Sets ``self.model`` (in eval mode, on the selected device),
        ``self.processor``, and ``self.device``.
    """
    # Select the compute device dynamically: use the GPU when one is
    # visible to torch, otherwise fall back to the CPU.
    self.device = "cuda" if torch.cuda.is_available() else "cpu"

    # trust_remote_code=True lets transformers execute the custom model
    # code shipped inside the checkpoint directory.
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        trust_remote_code=True,
    )
    self.model = model.eval().to(self.device)

    self.processor = AutoProcessor.from_pretrained(
        model_dir,
        trust_remote_code=True,
    )
|
| 22 |
|
| 23 |
def __call__(self, data):
|
| 24 |
# Extract inputs from the request data
|
|
|
|
| 33 |
text=task_prompt,
|
| 34 |
images=image,
|
| 35 |
return_tensors="pt"
|
| 36 |
+
).to(self.device) # Use the correct device
|
| 37 |
|
| 38 |
# Generate output
|
| 39 |
generated_ids = self.model.generate(
|