sam-brause commited on
Commit ·
edd3afe
1
Parent(s): 85c4bdf
Updated handler with EndpointHandler class
Browse files- handler.py +22 -14
handler.py
CHANGED
|
@@ -5,7 +5,7 @@ import torchvision.transforms as transforms
|
|
| 5 |
from PIL import Image
|
| 6 |
import json
|
| 7 |
|
| 8 |
-
# Hugging Face
|
| 9 |
HF_REPO = "your-username/your-model-name"
|
| 10 |
MODEL_FILENAME = "model_scripted_efficientnet.pt"
|
| 11 |
MODEL_PATH = f"/repository/{MODEL_FILENAME}"
|
|
@@ -54,20 +54,28 @@ supported_issues = [
|
|
| 54 |
"Brow Asymmetry"
|
| 55 |
]
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
|
|
|
| 66 |
|
| 67 |
-
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
return json.dumps(results)
|
|
|
|
| 5 |
from PIL import Image
|
| 6 |
import json
|
| 7 |
|
| 8 |
+
# Hugging Face Model Information
|
| 9 |
HF_REPO = "your-username/your-model-name"
|
| 10 |
MODEL_FILENAME = "model_scripted_efficientnet.pt"
|
| 11 |
MODEL_PATH = f"/repository/{MODEL_FILENAME}"
|
|
|
|
| 54 |
"Brow Asymmetry"
|
| 55 |
]
|
| 56 |
|
| 57 |
+
class EndpointHandler:
|
| 58 |
+
def __init__(self, model_dir=None):
|
| 59 |
+
"""Initialize the inference model."""
|
| 60 |
+
self.model = model
|
| 61 |
+
self.device = device
|
| 62 |
+
self.model.to(self.device)
|
| 63 |
+
self.model.eval()
|
| 64 |
+
print("Model loaded successfully.")
|
| 65 |
|
| 66 |
+
def __call__(self, data):
|
| 67 |
+
"""Perform inference on an image."""
|
| 68 |
+
if "image" not in data:
|
| 69 |
+
return {"error": "No image provided"}
|
| 70 |
|
| 71 |
+
image_data = data["image"]
|
| 72 |
+
image = Image.open(image_data).convert("RGB")
|
| 73 |
+
image = transform(image).unsqueeze(0).to(self.device)
|
| 74 |
|
| 75 |
+
with torch.no_grad():
|
| 76 |
+
outputs = self.model(image)
|
| 77 |
|
| 78 |
+
predictions = outputs.squeeze().tolist()
|
| 79 |
+
output_labels = [label for label, prob in zip(supported_issues, predictions) if prob > 0.5]
|
| 80 |
+
|
| 81 |
+
return {"predictions": output_labels}
|
|
|